/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"

#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/statusor.h"

using llvm::APInt;
using llvm::makeArrayRef;
using mlir::DenseIntElementsAttr;
using mlir::NamedAttribute;
using mlir::Operation;
using mlir::RankedTensorType;
using mlir::Type;
using mlir::Value;
using mlir::func::FuncOp;

namespace xla {

namespace {

constexpr char kShardingAttr[] = "mhlo.sharding";

// Note: This sanitization function causes an irreversible many-to-one mapping,
// and any solution to mitigate this would cause issues with the reverse
// direction. The long-term solution is to add a function attribute that
// preserves the original HLO name.
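// For example, "region-0" and "region_0" both sanitize to "region_0", so the
// original name cannot be recovered from the sanitized one.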
std::string SanitizeFunctionName(llvm::StringRef name) {
  std::string output(name);
  llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
  return output;
}

// Returns whether the instruction is a default dot operation.
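// A default dot contracts the last dimension of the lhs with the first
// dimension of the rhs, e.g. matrix[m, k] dot matrix[k, n] => [m, n].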
bool DotIsDefault(const HloInstruction* instruction) {
  const auto& operands = instruction->operands();
  // e.g. vector[3] dot matrix[3, 2] => [2] is not a default dot.
  if (operands[0]->shape().rank() < operands[1]->shape().rank()) {
    return false;
  }
  auto dnums = instruction->dot_dimension_numbers();
  DotDimensionNumbers default_dimension_numbers;
  default_dimension_numbers.add_lhs_contracting_dimensions(
      instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
  default_dimension_numbers.add_rhs_contracting_dimensions(0);
  return xla::protobuf_util::ProtobufEquals(dnums, default_dimension_numbers);
}

// Returns an MLIR Location generated from the HLO instruction. Uses the
// instruction's metadata if present, falling back to the instruction name.
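// For example, an instruction with op_name "foo/bar" and source location
// "model.py" line 42 roughly yields loc(fused["foo/bar", "model.py":42:0]).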
mlir::Location GenerateInstructionLocation(const HloInstruction* instruction,
                                           mlir::OpBuilder* func_builder) {
  const std::string& op_name = instruction->metadata().op_name();
  if (op_name.empty()) {
    return mlir::NameLoc::get(
        func_builder->getStringAttr(instruction->name()));
  }

  mlir::Location op_name_loc =
      mlir::NameLoc::get(func_builder->getStringAttr(op_name));
  const std::string& source_file = instruction->metadata().source_file();
  if (source_file.empty()) {
    return op_name_loc;
  }

  return func_builder->getFusedLoc(
      {op_name_loc,
       mlir::FileLineColLoc::get(func_builder->getContext(), source_file,
                                 instruction->metadata().source_line(), 0)});
}

// Cleans up the GetTupleElementOps created during the flattening of tuple
// arguments and return values, folding them where eligible. Removing a
// get-tuple-element can transitively make its defining TupleOp dead, so dead
// TupleOps are removed as well.
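// For example (pseudo-IR):
//   %t = mhlo.tuple(%a, %b)
//   %x = mhlo.get_tuple_element(%t)[0]
// folds %x to %a; once all such uses are folded, %t becomes trivially dead
// and is erased.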
void CleanUpTupleOps(mlir::Block* block, mlir::OpBuilder* builder) {
  bool changed = true;
  llvm::SmallVector<Value> folded_results;

  while (changed) {
    changed = false;
    for (Operation& op : llvm::make_early_inc_range(block->getOperations())) {
      if (llvm::isa<mlir::mhlo::GetTupleElementOp>(op)) {
        folded_results.clear();
        if (failed(builder->tryFold(&op, folded_results))) continue;
        op.replaceAllUsesWith(folded_results);
        op.erase();
        changed = true;
      } else if (llvm::isa<mlir::mhlo::TupleOp>(op) &&
                 mlir::isOpTriviallyDead(&op)) {
        op.erase();
        changed = true;
      }
    }
  }
}

}  // namespace

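// Replaces all uses of the region block arguments of `op` (which must be an
// mhlo::IfOp or mhlo::CaseOp) with the given implicit operands, then erases
// the now-unused block arguments.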
void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands(
    mlir::Operation* op, llvm::ArrayRef<mlir::Value> implicit_operands) {
  assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
          mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
         "Unexpected mlir op in "
         "HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands!");

  int implicit_operand_index = 0;
  for (auto& region : op->getRegions()) {
    for (auto arg : region.getArguments()) {
      assert(implicit_operand_index < implicit_operands.size());
      arg.replaceAllUsesWith(implicit_operands[implicit_operand_index++]);
    }
    region.front().eraseArguments(
        llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
  }
}

mlir::Operation* HloFunctionImporter::CreateTupleFromOpResults(
    mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op,
    mlir::Type type) {
  if (!type.isa<mlir::TupleType>()) return op;

  llvm::SmallVector<Value> flattened_results = op->getResults();
  llvm::MutableArrayRef<mlir::Value> flattened_results_ref(flattened_results);
  auto result =
      CreateTupleValue(func_builder, loc, flattened_results_ref, type);
  auto defining_tuple_op = result.getDefiningOp<mlir::mhlo::TupleOp>();
  assert(defining_tuple_op && "builder didn't return the right type");
  return defining_tuple_op.getOperation();
}

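// Returns true if the data part of an infeed result type tuple<data, token>
// is itself a tuple that contains a nested tuple, e.g.
// tuple<tuple<tuple<i32>, f32>, token> but not tuple<tuple<i32, f32>, token>.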
static bool IsNestedTupleInData(Type type) {
  auto tuple_type = type.dyn_cast<mlir::TupleType>();
  if (!tuple_type) return false;

  assert(tuple_type.getType(1).isa<mlir::mhlo::TokenType>() &&
         "Infeed: Non token type");
  auto data_type = tuple_type.getType(0);

  auto data_tuple_type = data_type.dyn_cast<mlir::TupleType>();
  if (!data_tuple_type) return false;

  for (auto child_type : data_tuple_type.getTypes()) {
    if (child_type.isa<mlir::TupleType>()) return true;
  }

  return false;
}

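// Recursively flattens a (possibly nested) tuple type into `flattened_types`
// in declaration order; non-tuple types are appended as-is, e.g.
// tuple<tuple<i32, f32>, i1> flattens to [i32, f32, i1].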
void HloFunctionImporter::FlattenTupleType(
    Type type, llvm::SmallVectorImpl<Type>& flattened_types) {
  auto tuple_type = type.dyn_cast<mlir::TupleType>();
  if (!tuple_type) {
    flattened_types.push_back(type);
    return;
  }

  for (auto child_type : tuple_type.getTypes()) {
    FlattenTupleType(child_type, flattened_types);
  }
}

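// Recursively emits mhlo::GetTupleElementOps to flatten a (possibly nested)
// tuple value into `flattened_values`, mirroring FlattenTupleType.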
void HloFunctionImporter::FlattenTupleValue(
    mlir::OpBuilder* func_builder, mlir::Location loc, Value value,
    llvm::SmallVectorImpl<Value>& flattened_values) {
  auto tuple_type = value.getType().dyn_cast<mlir::TupleType>();
  if (!tuple_type) {
    flattened_values.push_back(value);
    return;
  }

  int flatten_idx = 0;
  for (auto child_type : tuple_type.getTypes()) {
    auto sub_value = func_builder->create<mlir::mhlo::GetTupleElementOp>(
        loc, child_type, value,
        func_builder->getI32IntegerAttr(flatten_idx++));
    FlattenTupleValue(func_builder, loc, sub_value, flattened_values);
  }
}

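// Rebuilds a (possibly nested) tuple value of the given `type` from the
// leading elements of `flatten_values`, consuming them from the front; this
// is the inverse of FlattenTupleValue.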
Value HloFunctionImporter::CreateTupleValue(
    mlir::OpBuilder* func_builder, mlir::Location loc,
    llvm::MutableArrayRef<Value>& flatten_values, Type type) {
  auto tuple_type = type.dyn_cast<mlir::TupleType>();
  if (!tuple_type) {
    assert(!flatten_values.empty());
    auto retval = flatten_values.front();
    flatten_values = flatten_values.drop_front();
    return retval;
  }

  llvm::SmallVector<mlir::Value> flatten_sub_values;
  for (auto child_type : tuple_type.getTypes())
    flatten_sub_values.push_back(
        CreateTupleValue(func_builder, loc, flatten_values, child_type));

  return func_builder->create<mlir::mhlo::TupleOp>(loc, flatten_sub_values)
      .getResult();
}

Status HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation, mlir::ModuleOp module,
    std::unordered_map<const HloComputation*, FuncOp>* function_map,
    mlir::Builder* builder, bool is_main) {
  HloFunctionImporter importer(module, function_map, builder);
  return importer.ImportAsFunc(computation, is_main).status();
}

Status HloFunctionImporter::ImportAsRegion(
    const xla::HloComputation& computation, mlir::Region* region,
    mlir::Builder* builder, bool flatten_region_arg_tuple) {
  HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
                               builder);
  return importer.ImportAsRegion(computation, region, flatten_region_arg_tuple);
}

StatusOr<FuncOp> HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation, bool is_main) {
  std::string computation_name =
      is_main ? "main" : SanitizeFunctionName(computation.name());

  FuncOp* imported(nullptr);
  if (function_map_) {
    imported = &((*function_map_)[&computation]);
    if (*imported) {
      return *imported;
    }
  } else {
    TF_RET_CHECK(!module_.lookupSymbol<FuncOp>(computation_name))
        << "Attempting to redeclare an existing function named "
        << computation.name();
  }
  llvm::SmallVector<Type, 4> args, rets;
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
  TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
  auto func_type = mlir::FunctionType::get(context_, args, rets);

  // Construct the MLIR function and map arguments.
  llvm::ArrayRef<mlir::NamedAttribute> attrs;
  auto function = FuncOp::create(mlir::UnknownLoc::get(context_),
                                 computation_name, func_type, attrs);
  auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
                                               : FuncOp::Visibility::Private;
  function.setVisibility(visibility);

  for (auto& entry : llvm::enumerate(computation.parameter_instructions())) {
    HloInstruction* parameter = entry.value();
    if (parameter->has_sharding()) {
      function.setArgAttr(
          entry.index(), kShardingAttr,
          builder_->getStringAttr(
              parameter->sharding().ToProto().SerializeAsString()));
    }
  }
  if (computation.root_instruction()->has_sharding()) {
    auto result = computation.root_instruction();
    if (function.getNumResults() != 1) {
      return tensorflow::errors::Internal(absl::StrCat(
          "Expected only a single result but got ", function.getNumResults()));
    }
    function.setResultAttr(
        0, kShardingAttr,
        builder_->getStringAttr(
            result->sharding().ToProto().SerializeAsString()));
  }

  module_.push_back(function);

  // Add to the map right away for function calls if map is set.
  if (imported) {
    *imported = function;
  }

  mlir::Block* block = function.addEntryBlock();
  TF_RETURN_IF_ERROR(ImportInstructions(computation, block,
                                        /*flatten_region_arg_tuple=*/false));

  return function;
}

tensorflow::Status HloFunctionImporter::ImportAsRegion(
    const HloComputation& computation, mlir::Region* region,
    bool flatten_region_arg_tuple) {
  auto loc = region->getLoc();
  // TODO(hinsu): Store computation name as an attribute for round-trip.
  auto* block = new mlir::Block;
  region->push_back(block);

  llvm::SmallVector<Type, 4> args;
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));

  // Flatten the tuple-typed arguments.
  if (flatten_region_arg_tuple) {
    for (auto arg : args) {
      llvm::SmallVector<Type> flattened_arg_types;
      FlattenTupleType(arg, flattened_arg_types);
      block->addArguments(
          flattened_arg_types,
          mlir::SmallVector<mlir::Location>(flattened_arg_types.size(), loc));
    }
  } else {
    block->addArguments(args,
                        mlir::SmallVector<mlir::Location>(args.size(), loc));
  }

  return ImportInstructions(computation, block, flatten_region_arg_tuple);
}

StatusOr<Value> HloFunctionImporter::ImportInstructionsImpl(
    const xla::HloComputation& computation,
    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
  // Set up the input parameters.
  const int num_parameters = computation.num_parameters();

  for (int i = 0; i < num_parameters; i++) {
    auto hlo_parameter = computation.parameter_instruction(i);
    instruction_value_map_[hlo_parameter] = arguments[i];
  }

  for (auto instruction : computation.MakeInstructionPostOrder()) {
    TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
    TF_ASSIGN_OR_RETURN(
        auto new_operation,
        ImportInstructionWithLayout(instruction, operands, builder));
    if (new_operation) {
      instruction_value_map_[instruction] = new_operation->getResult(0);
    }
  }

  // Return the value of the root instruction (HLO only supports a single
  // return value).
  return GetMlirValue(computation.root_instruction());
}

Status HloFunctionImporter::ImportInstructions(
    const HloComputation& computation, mlir::Block* block,
    bool flatten_region_arg_tuple) {
  llvm::SmallVector<Value, 4> arguments(block->args_begin(), block->args_end());
  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);

  // TODO(suderman): Add location tracking details.
  mlir::Location loc = builder.getUnknownLoc();

  Value result;
  if (!llvm::isa<FuncOp>(block->getParentOp()) && flatten_region_arg_tuple) {
    // 'effective_arguments' stores the mhlo value corresponding to each
    // computation parameter. The value could be a BlockArgument, if the
    // corresponding computation parameter is non-tuple typed, or a TupleOp,
    // otherwise.
    llvm::SmallVector<Value> effective_arguments;

    llvm::SmallVector<Type> computation_arg_types;
    TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(),
                                    &computation_arg_types));
    int flatten_idx = 0;
    for (Type computation_arg_type : computation_arg_types) {
      auto orig_tuple_arg_type =
          computation_arg_type.dyn_cast<mlir::TupleType>();

      // If the computation-parameter type is non-tuple, no action is needed.
      if (!orig_tuple_arg_type) {
        effective_arguments.push_back(arguments[flatten_idx]);
        flatten_idx++;
        continue;
      }

      // For each tuple-typed computation parameter, create an mhlo::TupleOp
      // value in the region body, using the already flattened values in
      // 'arguments'. For example: with computation parameters [tuple<T1>,
      // tuple<T2, T3>] we have 'arguments' = [T1 arg1, T2 arg2, T3 arg3], and
      // we need to create two tuples, one using arg1, and the other using
      // arg2 and arg3.
      llvm::SmallVector<Type> flattened_arg_type;
      FlattenTupleType(orig_tuple_arg_type, flattened_arg_type);

      llvm::MutableArrayRef<Value> sub_args(
          arguments.begin() + flatten_idx,
          arguments.begin() + flatten_idx + flattened_arg_type.size());

      auto tuple_val =
          CreateTupleValue(&builder, loc, sub_args, orig_tuple_arg_type);
      effective_arguments.push_back(tuple_val);

      flatten_idx += flattened_arg_type.size();
    }

    TF_ASSIGN_OR_RETURN(
        result,
        ImportInstructionsImpl(computation, effective_arguments, &builder));
  } else {
    TF_ASSIGN_OR_RETURN(
        result, ImportInstructionsImpl(computation, arguments, &builder));
  }

  // Create the terminator op depending on the parent op of this region.
  if (llvm::isa<FuncOp>(block->getParentOp())) {
    builder.create<mlir::func::ReturnOp>(loc, result);
  } else {
    if (flatten_region_arg_tuple) {
      // Flatten tuples in the results of this region.
      llvm::SmallVector<Value> flattened_return_operands;
      FlattenTupleValue(&builder, loc, result, flattened_return_operands);
      builder.create<mlir::mhlo::ReturnOp>(loc, flattened_return_operands);
    } else {
      builder.create<mlir::mhlo::ReturnOp>(loc, result);
    }
  }

  CleanUpTupleOps(block, &builder);

  return ::tensorflow::OkStatus();
}

StatusOr<Value> HloFunctionImporter::ImportInstructions(
    const xla::HloComputation& computation,
    const llvm::SmallVectorImpl<Value>& arguments, mlir::OpBuilder* builder) {
  mlir::Block* block = builder->getBlock();
  if (block == nullptr)
    return InvalidArgument(
        "ImportInstructions requires a valid block in the builder");

  HloFunctionImporter importer(
      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);
  return importer.ImportInstructionsImpl(computation, arguments, builder);
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
    const xla::HloInstruction* instr,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* builder, DynamicShapeHandlingMode mode) {
  mlir::Block* block = builder->getBlock();
  if (block == nullptr)
    return InvalidArgument(
        "ImportInstruction requires a valid block in the builder");

  HloFunctionImporter importer(
      block->getParent()->getParentOfType<mlir::ModuleOp>(), {}, builder);

  return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
    const HloInstruction* instruction,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
  const Shape& instruction_shape = instruction->shape();
  const Shape& shape = mode == DynamicShapeHandlingMode::kConvertToStatic
                           ? xla::ShapeUtil::MakeStaticShape(instruction_shape)
                           : instruction_shape;
  TF_ASSIGN_OR_RETURN(auto result_type,
                      ConvertShapeToType<RankedTensorType>(shape, *builder_));
  mlir::Location loc = GenerateInstructionLocation(instruction, func_builder);

  llvm::SmallVector<NamedAttribute, 10> attributes;
  if (instruction->has_sharding()) {
    attributes.push_back(builder_->getNamedAttr(
        kShardingAttr,
        builder_->getStringAttr(
            instruction->sharding().ToProto().SerializeAsString())));
  }

  switch (instruction->opcode()) {
    case HloOpcode::kParameter: {
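      // Parameters were already mapped to the block arguments (or effective
      // arguments) up front, so there is no operation to create here.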
      return nullptr;
    }
    case HloOpcode::kConstant: {
      const Literal& literal = instruction->literal();
      auto attr = CreateDenseElementsAttrFromLiteral(literal, *builder_);
      if (!attr.ok()) return attr.status();
      mlir::Operation* new_operation =
          func_builder->create<mlir::mhlo::ConstantOp>(loc, attr.ValueOrDie());
      for (auto attr : attributes) {
        new_operation->setAttr(attr.getName(), attr.getValue());
      }
      return new_operation;
    }
    case HloOpcode::kIota: {
      return func_builder
          ->create<mlir::mhlo::IotaOp>(
              loc, result_type,
              func_builder->getI64IntegerAttr(
                  Cast<HloIotaInstruction>(instruction)->iota_dimension()))
          .getOperation();
    }
    case HloOpcode::kBroadcast: {
      // Note that the HLO broadcast is more powerful than the XLA broadcast
      // op. BroadcastInDim offers a superset of the HLO op's functionality.
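      // For example, broadcasting f32[3] into f32[2, 3] uses
      // broadcast_dimensions = [1], mapping operand dimension 0 to result
      // dimension 1.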
      attributes.push_back(
          builder_->getNamedAttr("broadcast_dimensions",
                                 ConvertDimensions(instruction->dimensions())));
      return func_builder
          ->create<mlir::mhlo::BroadcastInDimOp>(loc, result_type, operands,
                                                 attributes)
          .getOperation();
    }

    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
      attributes.push_back(builder_->getNamedAttr(
          "epsilon", builder_->getF32FloatAttr(instruction->epsilon())));
      attributes.push_back(builder_->getNamedAttr(
          "feature_index",
          builder_->getI64IntegerAttr(instruction->feature_index())));
      if (instruction->opcode() == HloOpcode::kBatchNormGrad) {
        // Flatten the return type if it is tuple-typed.
        llvm::SmallVector<Type> flattened_ret_types;
        FlattenTupleType(result_type, flattened_ret_types);

        auto op = func_builder
                      ->create<mlir::mhlo::BatchNormGradOp>(
                          loc, flattened_ret_types, operands, attributes)
                      .getOperation();

        return CreateTupleFromOpResults(func_builder, loc, op, result_type);
      } else if (instruction->opcode() == HloOpcode::kBatchNormInference) {
        return func_builder
            ->create<mlir::mhlo::BatchNormInferenceOp>(loc, result_type,
                                                       operands, attributes)
            .getOperation();
      } else {
        assert(instruction->opcode() == HloOpcode::kBatchNormTraining);

        // Flatten the return type if it is tuple-typed.
        llvm::SmallVector<Type> flattened_ret_types;
        FlattenTupleType(result_type, flattened_ret_types);

        auto op = func_builder
                      ->create<mlir::mhlo::BatchNormTrainingOp>(
                          loc, flattened_ret_types, operands, attributes)
                      .getOperation();

        return CreateTupleFromOpResults(func_builder, loc, op, result_type);
      }

    case HloOpcode::kDot: {
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));

      // Consider consolidating DotOps together.
      if (DotIsDefault(instruction)) {
        return func_builder
            ->create<mlir::mhlo::DotOp>(loc, result_type, operands, attributes)
            .getOperation();
      }

      attributes.push_back(builder_->getNamedAttr(
          "dot_dimension_numbers",
          ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
                                     builder_)));
      return func_builder
          ->create<mlir::mhlo::DotGeneralOp>(loc, result_type, operands,
                                             attributes)
          .getOperation();
    }
    case HloOpcode::kCall: {
      TF_ASSIGN_OR_RETURN(
          FuncOp function,
          ImportAsFunc(*instruction->to_apply(), /*is_main=*/false));
      mlir::Operation* new_operation =
          func_builder->create<mlir::func::CallOp>(loc, function, operands);
      return new_operation;
    }
    case HloOpcode::kCollectivePermute: {
      attributes.push_back(ConvertSourceTargetPairs(
          instruction->source_target_pairs(), builder_));
      return func_builder
          ->create<mlir::mhlo::CollectivePermuteOp>(loc, result_type, operands,
                                                    attributes)
          .getOperation();
    }
    case HloOpcode::kCustomCall: {
      auto custom_call = Cast<HloCustomCallInstruction>(instruction);
      const auto& called_computations = custom_call->called_computations();
      if (!called_computations.empty()) {
        llvm::SmallVector<mlir::Attribute> callees;
        callees.reserve(called_computations.size());
        for (HloComputation* callee : called_computations) {
          TF_ASSIGN_OR_RETURN(FuncOp function,
                              ImportAsFunc(*callee, /*is_main=*/false));
          callees.push_back(mlir::FlatSymbolRefAttr::get(builder_->getContext(),
                                                         function.getName()));
        }
        attributes.push_back(builder_->getNamedAttr(
            "called_computations",
            mlir::ArrayAttr::get(builder_->getContext(), callees)));
      }
      if (custom_call->layout_constrained()) {
        TF_ASSIGN_OR_RETURN(
            mlir::ArrayAttr operand_layouts,
            ExtractLayoutsFromShapes(custom_call->operand_shapes_with_layout(),
                                     builder_));
        attributes.push_back(
            builder_->getNamedAttr("operand_layouts", operand_layouts));
        mlir::ArrayAttr result_layouts;
        if (custom_call->shape().IsTuple()) {
          TF_ASSIGN_OR_RETURN(
              result_layouts,
              ExtractLayoutsFromTuple(custom_call->shape(), builder_));
        } else {
          TF_ASSIGN_OR_RETURN(
              result_layouts,
              ExtractLayoutsFromShapes({custom_call->shape()}, builder_));
        }
        attributes.push_back(
            builder_->getNamedAttr("result_layouts", result_layouts));
      }

      TF_ASSIGN_OR_RETURN(
          auto mlir_api_version,
          ConvertCustomCallApiVersion(custom_call->api_version()));
      attributes.push_back(builder_->getNamedAttr(
          "call_target_name",
          builder_->getStringAttr(custom_call->custom_call_target())));
      attributes.push_back(builder_->getNamedAttr(
          "has_side_effect",
          builder_->getBoolAttr(custom_call->custom_call_has_side_effect())));
      attributes.push_back(builder_->getNamedAttr(
          "backend_config",
          builder_->getStringAttr(custom_call->raw_backend_config_string())));
      attributes.push_back(builder_->getNamedAttr(
          "api_version", mlir::mhlo::CustomCallApiVersionAttr::get(
                             builder_->getContext(), mlir_api_version)));
      return func_builder
          ->create<mlir::mhlo::CustomCallOp>(loc, result_type, operands,
                                             attributes)
          .getOperation();
    }
    case HloOpcode::kCompare: {
      auto compare = Cast<HloCompareInstruction>(instruction);
      attributes.push_back(ConvertComparisonDirection(compare->direction()));
      auto default_type = Comparison::DefaultComparisonType(
          compare->operand(0)->shape().element_type());
      if (compare->type() != default_type)
        attributes.push_back(ConvertComparisonType(compare->type()));
      return func_builder
          ->create<mlir::mhlo::CompareOp>(loc, result_type, operands,
                                          attributes)
          .getOperation();
    }
    case HloOpcode::kCholesky: {
      attributes.push_back(builder_->getNamedAttr(
          "lower",
          builder_->getBoolAttr(instruction->cholesky_options().lower())));
      return func_builder
          ->create<mlir::mhlo::CholeskyOp>(loc, result_type, operands,
                                           attributes)
          .getOperation();
    }
    case HloOpcode::kGather: {
      auto gather_instruction = Cast<HloGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertGatherDimensionNumbers(
              gather_instruction->gather_dimension_numbers(), builder_)));

      std::vector<int64_t> slice_sizes(
          gather_instruction->gather_slice_sizes().begin(),
          gather_instruction->gather_slice_sizes().end());
      attributes.push_back(
          builder_->getNamedAttr("slice_sizes", Convert(slice_sizes)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(gather_instruction->indices_are_sorted())));

      return func_builder
          ->create<mlir::mhlo::GatherOp>(loc, result_type, operands, attributes)
          .getOperation();
    }
    case HloOpcode::kDynamicSlice: {
      std::vector<int64_t> slice_sizes(
          instruction->dynamic_slice_sizes().begin(),
          instruction->dynamic_slice_sizes().end());
      return func_builder
          ->create<mlir::mhlo::DynamicSliceOp>(
              loc, result_type, operands[0],
              makeArrayRef(operands).drop_front(), Convert(slice_sizes))
          .getOperation();
    }
    case HloOpcode::kDynamicUpdateSlice: {
      return func_builder
          ->create<mlir::mhlo::DynamicUpdateSliceOp>(
              loc, result_type, operands[0], operands[1],
              llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
          .getOperation();
    }
    case HloOpcode::kInfeed: {
      if (IsNestedTupleInData(result_type)) {
        llvm_unreachable(
            "Importing xla::kInfeed with nested tuple shape not supported");
      }

      attributes.push_back(builder_->getNamedAttr(
          "infeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->infeed_config())));

      llvm::SmallVector<mlir::Attribute> flattened_attr;
      TF_RETURN_IF_ERROR(
          ConvertShapeToMlirLayout(instruction->shape(), flattened_attr));
      attributes.push_back(builder_->getNamedAttr(
          "layout", builder_->getArrayAttr(makeArrayRef(flattened_attr))));

      // Flatten the return type if it is tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto op = func_builder->create<mlir::mhlo::InfeedOp>(
          loc, flattened_ret_types, operands, attributes);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
    case HloOpcode::kOutfeed: {
      attributes.push_back(builder_->getNamedAttr(
          "outfeed_config",
          mlir::StringAttr::get(builder_->getContext(),
                                instruction->outfeed_config())));

      assert(operands.size() == 2 && "Expected 2 operands for HLO Outfeed");

      // In case operands[0] is a tuple, flatten it.
      llvm::SmallVector<Value> flattened_operands;
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
      flattened_operands.push_back(operands[1]);

      auto op = func_builder->create<mlir::mhlo::OutfeedOp>(
          loc, result_type, flattened_operands, attributes);

      return op.getOperation();
    }
    case HloOpcode::kPad: {
      const auto& padding_config = instruction->padding_config();
      llvm::SmallVector<int64_t, 4> edge_padding_low;
      llvm::SmallVector<int64_t, 4> edge_padding_high;
      llvm::SmallVector<int64_t, 4> interior_padding;
      edge_padding_low.reserve(padding_config.dimensions_size());
      edge_padding_high.reserve(padding_config.dimensions_size());
      interior_padding.reserve(padding_config.dimensions_size());

      for (const auto& dimension : padding_config.dimensions()) {
        edge_padding_low.push_back(dimension.edge_padding_low());
        edge_padding_high.push_back(dimension.edge_padding_high());
        interior_padding.push_back(dimension.interior_padding());
      }

      return func_builder
          ->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
                                      operands[1], Convert(edge_padding_low),
                                      Convert(edge_padding_high),
                                      Convert(interior_padding))
          .getOperation();
    }
    case HloOpcode::kScatter: {
      auto scatter = Cast<HloScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension_numbers",
          ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
                                         builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "indices_are_sorted",
          builder_->getBoolAttr(scatter->indices_are_sorted())));
      attributes.push_back(builder_->getNamedAttr(
          "unique_indices", builder_->getBoolAttr(scatter->unique_indices())));

      llvm::SmallVector<Type> flattened_types;
      FlattenTupleType(result_type, flattened_types);

      auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
          loc, flattened_types, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
                                        &scatter_op.update_computation(),
                                        /*flatten_region_arg_tuple=*/true));
      TF_ASSIGN_OR_RETURN(auto result_type,
                          ConvertShapeToType<RankedTensorType>(
                              instruction->shape(), *builder_));
      return CreateTupleFromOpResults(func_builder, loc,
                                      scatter_op.getOperation(), result_type);
    }
    case HloOpcode::kSelectAndScatter: {
      auto select_scatter = Cast<HloSelectAndScatterInstruction>(instruction);
      llvm::SmallVector<int64_t, 4> window_strides, window_dimensions;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : select_scatter->window().dimensions()) {
        window_strides.push_back(dim.stride());
        window_dimensions.push_back(dim.size());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(window_strides)));
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  Convert(window_dimensions)));
      attributes.push_back(ConvertPadding(padding));
      auto select_scatter_op =
          func_builder->create<mlir::mhlo::SelectAndScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
                                        &select_scatter_op.select(),
                                        /*flatten_region_arg_tuple=*/true));
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
                                        &select_scatter_op.scatter(),
                                        /*flatten_region_arg_tuple=*/true));
      return select_scatter_op.getOperation();
    }
    case HloOpcode::kSetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::SetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    }
    case HloOpcode::kSlice: {
      return func_builder
          ->create<mlir::mhlo::SliceOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->slice_starts()),
              ConvertDimensions(instruction->slice_limits()),
              ConvertDimensions(instruction->slice_strides()))
          .getOperation();
    }
    case HloOpcode::kSort: {
      auto sort_instruction = Cast<HloSortInstruction>(instruction);

      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }

      auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
          loc, return_types, operands,
          builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
          builder_->getBoolAttr(sort_instruction->is_stable()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*sort_instruction->to_apply(),
                                        &sort_op.comparator(),
                                        /*flatten_region_arg_tuple=*/true));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return sort_op.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, sort_op.getResults())
          .getOperation();
    }
    case HloOpcode::kConditional: {
      llvm::SmallVector<Type, 4> rets;

      // Flatten the tuple-typed operands.
      llvm::SmallVector<Value> flattened_operands;
      for (auto& operand : operands)
        FlattenTupleValue(func_builder, loc, operand, flattened_operands);

      // If/Case Op has a single operand; we collect the other operands to
      // replace the corresponding block arguments.
      llvm::ArrayRef<Value> implicit_operands(flattened_operands.begin() + 1,
                                              flattened_operands.end());

      mlir::Type pred_or_index_type =
          operands[0].getType().cast<mlir::TensorType>().getElementType();
      // It is a predicated conditional if the first operand is a boolean, and
      // it should be mapped to an If op.
      if (pred_or_index_type.isInteger(1)) {
        TF_RETURN_IF_ERROR(GetMlirTypes(
            {instruction->true_computation()->root_instruction()}, &rets));

        // Flatten the return type.
        llvm::SmallVector<Type> flattened_ret_types;
        assert(rets.size() == 1);
        FlattenTupleType(rets[0], flattened_ret_types);

        auto op = func_builder->create<mlir::mhlo::IfOp>(
            loc, flattened_ret_types, flattened_operands[0], attributes);
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
                                          &op.true_branch(),
                                          /*flatten_region_arg_tuple=*/true));
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
                                          &op.false_branch(),
                                          /*flatten_region_arg_tuple=*/true));

        // Replace the uses of block arguments of the IfOp with the
        // implicit_operands.
        ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                  implicit_operands);

        return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                        rets[0]);
      }

      // Otherwise, it is an indexed conditional and should be mapped to a
      // Case op.
      TF_RETURN_IF_ERROR(GetMlirTypes(
          {instruction->branch_computation(0)->root_instruction()}, &rets));

      // Flatten the return type.
      llvm::SmallVector<Type> flattened_ret_types;
      assert(rets.size() == 1);
      FlattenTupleType(rets[0], flattened_ret_types);

      int num_branches = instruction->branch_count();
      auto op = func_builder->create<mlir::mhlo::CaseOp>(
          loc, flattened_ret_types, flattened_operands[0], attributes,
          num_branches);
      for (const auto& index_and_computation :
           llvm::enumerate(instruction->branch_computations())) {
        auto index = index_and_computation.index();
        HloComputation* computation = index_and_computation.value();
        TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index],
                                          /*flatten_region_arg_tuple=*/true));
      }

      // Replace the uses of block arguments of the CaseOp with the
      // implicit_operands.
      ReplaceBlockArgumentsWithImplicitOperands(op.getOperation(),
                                                implicit_operands);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      rets[0]);
    }
    case HloOpcode::kConcatenate: {
      // TODO(b/132057942): Support taking a uint64_t instead of an
      // IntegerAttr for the concatenate dimension.
      return func_builder
          ->create<mlir::mhlo::ConcatenateOp>(
              loc, result_type, operands,
              builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
          .getOperation();
    }
    case HloOpcode::kAllGather: {
      auto all_gather = Cast<HloAllGatherInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "all_gather_dim",
          builder_->getI64IntegerAttr(all_gather->all_gather_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(all_gather->replica_groups(), builder_));
      if (all_gather->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_gather->channel_id().value()));
      return func_builder
          ->create<mlir::mhlo::AllGatherOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kAllReduce: {
      auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
      attributes.push_back(
          ConvertReplicaGroups(all_reduce->replica_groups(), builder_));
      if (all_reduce->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(all_reduce->channel_id().value()));
      if (all_reduce->use_global_device_ids())
        attributes.push_back(ConvertUseGlobalDeviceIds());
      auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
                                        &all_reduce_op.computation()));
      return all_reduce_op.getOperation();
    }
    case HloOpcode::kAllToAll: {
      // TODO(b/207152612): all-to-all HLO can either take pre-split operands
      // (and return a tuple) or a single operand that is split across
      // `split_dimension` into the number of replicas in a group. Only the
      // latter case (array all-to-all) is supported by the importer right
      // now; the former (tuple all-to-all) is not supported yet.
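      // For example, with replica_groups = {{0, 1}} and split_dimension = 0,
      // an f32[4, 2] operand is split into two f32[2, 2] chunks, exchanged
      // between the two replicas, and concatenated back along dimension 0.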
      auto all_to_all = Cast<HloAllToAllInstruction>(instruction);
      if (all_to_all->shape().IsTuple())
        return tensorflow::errors::Unimplemented(
            "Importing tuple all-to-all HLO is not supported yet");

      // Check invariants of array all-to-all. This is a sanity check and is
      // verified by the HLO verifier.
      if (!all_to_all->split_dimension().has_value() || operands.size() != 1 ||
          all_to_all->replica_groups().empty())
        return tensorflow::errors::InvalidArgument(
            "Array all-to-all should have a split dimension, one operand and "
            "non-empty replica groups");

      auto replica_groups_attr =
          ConvertReplicaGroups(all_to_all->replica_groups(), builder_)
              .getValue()
              .cast<DenseIntElementsAttr>();
      uint64_t split_dim = all_to_all->split_dimension().value();
      uint64_t concat_dim = split_dim;
      uint64_t split_count = all_to_all->replica_groups()[0].replica_ids_size();

      return func_builder
          ->create<mlir::mhlo::AllToAllOp>(loc, result_type, operands[0],
                                           split_dim, concat_dim, split_count,
                                           replica_groups_attr)
          .getOperation();
    }
    case HloOpcode::kReduce: {
      // Operands in the first half are reduction inputs and the remaining
      // operands are the corresponding initial values.
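      // For example, a variadic reduce of (x, y) with init values (x0, y0)
      // has operands [x, y, x0, y0] and num_inputs == 2.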
      size_t num_inputs = operands.size() / 2;
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }

      auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
          loc, return_types,
          llvm::makeArrayRef(operands).take_front(num_inputs),
          llvm::makeArrayRef(operands).drop_front(num_inputs),
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                        &reduce.body(),
                                        /*flatten_region_arg_tuple=*/true));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kReverse: {
      return func_builder
          ->create<mlir::mhlo::ReverseOp>(
              loc, result_type, operands[0],
              ConvertDimensions(instruction->dimensions()))
          .getOperation();
    }
    case HloOpcode::kRng: {
      auto shape = func_builder->create<mlir::mhlo::ConstantOp>(
          loc, Convert(result_type.cast<RankedTensorType>().getShape()));
      switch (instruction->random_distribution()) {
        case xla::RNG_UNIFORM:
          return func_builder
              ->create<mlir::mhlo::RngOp>(
                  loc, result_type, operands[0], operands[1], shape,
                  ::mlir::mhlo::RngDistribution::UNIFORM)
              .getOperation();

        case xla::RNG_NORMAL:
          return func_builder
              ->create<mlir::mhlo::RngOp>(loc, result_type, operands[0],
                                          operands[1], shape,
                                          ::mlir::mhlo::RngDistribution::NORMAL)
              .getOperation();

        default:
          return tensorflow::errors::InvalidArgument(absl::StrCat(
              "Unsupported distribution: ",
              RandomDistributionToString(instruction->random_distribution())));
      }
    }
    case HloOpcode::kRngBitGenerator: {
      auto rng_op = Cast<HloRngBitGeneratorInstruction>(instruction);

      // Flatten the return type if it is tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
          builder_->getContext(),
          *mlir::mhlo::symbolizeRngAlgorithm(rng_op->algorithm()));
      auto op = func_builder->create<mlir::mhlo::RngBitGeneratorOp>(
          loc, flattened_ret_types, algorithm_attr, operands[0]);

      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      result_type);
    }
    case HloOpcode::kRngGetAndUpdateState: {
      return func_builder
          ->create<mlir::mhlo::XlaRngGetAndUpdateStateOp>(
              loc, result_type,
              func_builder->getI64IntegerAttr(
                  Cast<HloRngGetAndUpdateStateInstruction>(instruction)
                      ->delta()))
          .getOperation();
    }
    case HloOpcode::kWhile: {
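      // HLO while carries a single (possibly tuple-typed) value, whereas
      // mhlo::WhileOp takes and returns the flattened values, so the operand
      // is flattened here and the results are re-tupled below.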
      llvm::SmallVector<Value> flattened_operands;
      llvm::SmallVector<Type> flattened_operand_types;
      FlattenTupleType(operands[0].getType(), flattened_operand_types);
      FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);

      auto op = func_builder->create<mlir::mhlo::WhileOp>(
          loc, flattened_operand_types, flattened_operands);

      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_condition(),
                                        &op.cond(),
                                        /*flatten_region_arg_tuple=*/true));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->while_body(), &op.body(),
                                        /*flatten_region_arg_tuple=*/true));
      return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
                                      operands[0].getType());
    }
    case HloOpcode::kGetTupleElement: {
      attributes.push_back(builder_->getNamedAttr(
          "index", builder_->getIntegerAttr(builder_->getIntegerType(32),
                                            instruction->tuple_index())));
      return func_builder
          ->create<mlir::mhlo::GetTupleElementOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    }
    case HloOpcode::kGetDimensionSize: {
      attributes.push_back(builder_->getNamedAttr(
          "dimension", builder_->getI64IntegerAttr(instruction->dimension())));
      return func_builder
          ->create<mlir::mhlo::GetDimensionSizeOp>(loc, result_type, operands,
                                                   attributes)
          .getOperation();
    }
    case HloOpcode::kTranspose: {
      attributes.push_back(builder_->getNamedAttr(
          "permutation", ConvertDimensions(instruction->dimensions())));
      return func_builder
          ->create<mlir::mhlo::TransposeOp>(loc, result_type, operands,
                                            attributes)
          .getOperation();
    }
    case HloOpcode::kTriangularSolve: {
      attributes.push_back(builder_->getNamedAttr(
          "left_side",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().left_side())));
      attributes.push_back(builder_->getNamedAttr(
          "lower", builder_->getBoolAttr(
                       instruction->triangular_solve_options().lower())));
      attributes.push_back(builder_->getNamedAttr(
          "unit_diagonal",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().unit_diagonal())));
      auto transpose_a = mlir::mhlo::TransposeAttr::get(
          builder_->getContext(),
          mlir::mhlo::symbolizeTranspose(
              TriangularSolveOptions::Transpose_Name(
                  instruction->triangular_solve_options().transpose_a()))
              .getValue());

      attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
      return func_builder
          ->create<mlir::mhlo::TriangularSolveOp>(loc, result_type, operands,
                                                  attributes)
          .getOperation();
    }
    case HloOpcode::kReduceScatter: {
      auto reduce_scatter = Cast<HloReduceScatterInstruction>(instruction);
      attributes.push_back(builder_->getNamedAttr(
          "scatter_dimension",
          builder_->getI64IntegerAttr(reduce_scatter->scatter_dimension())));
      attributes.push_back(
          ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_));
      if (reduce_scatter->channel_id().has_value())
        attributes.push_back(
            ConvertChannelHandle(reduce_scatter->channel_id().value()));
      auto reduce_scatter_op =
          func_builder->create<mlir::mhlo::ReduceScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*reduce_scatter->to_apply(),
                                        &reduce_scatter_op.computation(),
                                        /*flatten_region_arg_tuple=*/true));

      return reduce_scatter_op.getOperation();
    }
    case HloOpcode::kReduceWindow: {
      llvm::SmallVector<Type, 4> return_types = {result_type};
      if (mlir::TupleType tuple_ty = result_type.dyn_cast<mlir::TupleType>()) {
        return_types = llvm::to_vector<6>(tuple_ty.getTypes());
      }
      llvm::SmallVector<int64_t, 4> sizes, strides, base_dilations,
          win_dilations;
      llvm::SmallVector<int64_t, 8> padding;
      for (const auto& dim : instruction->window().dimensions()) {
        sizes.push_back(dim.size());
        strides.push_back(dim.stride());
        base_dilations.push_back(dim.base_dilation());
        win_dilations.push_back(dim.window_dilation());
        padding.push_back(dim.padding_low());
        padding.push_back(dim.padding_high());
      }
      attributes.push_back(builder_->getNamedAttr("window_dimensions",
                                                  ConvertDimensions(sizes)));
      attributes.push_back(
          builder_->getNamedAttr("window_strides", ConvertDimensions(strides)));
      attributes.push_back(builder_->getNamedAttr(
          "base_dilations", ConvertDimensions(base_dilations)));
      attributes.push_back(builder_->getNamedAttr(
          "window_dilations", ConvertDimensions(win_dilations)));
      attributes.push_back(ConvertPadding(padding));
      auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
          loc, return_types, operands, attributes);
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                        &reduce.body(),
                                        /*flatten_region_arg_tuple=*/true));

      // Check if the output needs to be tupled.
      if (return_types.size() == 1 && return_types.front() == result_type) {
        return reduce.getOperation();
      }

      return func_builder
          ->create<mlir::mhlo::TupleOp>(loc, result_type, reduce.getResults())
          .getOperation();
    }
    case HloOpcode::kMap: {
      auto op = func_builder->create<mlir::mhlo::MapOp>(
          loc, result_type, operands,
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(),
                                        &op.computation(),
                                        /*flatten_region_arg_tuple=*/true));
      return op.getOperation();
    }
    case HloOpcode::kConvolution: {
      llvm::SmallVector<int64_t, 4> strides, lhs_dilations, rhs_dilations;
      llvm::SmallVector<bool, 4> reversals;
      llvm::SmallVector<int64_t, 8> paddings;
      for (const auto& dim : instruction->window().dimensions()) {
        strides.push_back(dim.stride());
        lhs_dilations.push_back(dim.base_dilation());
        rhs_dilations.push_back(dim.window_dilation());
        paddings.push_back(dim.padding_low());
        paddings.push_back(dim.padding_high());
        reversals.push_back(dim.window_reversal());
      }

      attributes.push_back(
          builder_->getNamedAttr("window_strides", Convert(strides)));
      attributes.push_back(ConvertPadding(paddings));
      attributes.push_back(
          builder_->getNamedAttr("lhs_dilation", Convert(lhs_dilations)));
      attributes.push_back(
          builder_->getNamedAttr("rhs_dilation", Convert(rhs_dilations)));
      attributes.push_back(
          builder_->getNamedAttr("window_reversal", Convert(reversals)));
      attributes.push_back(builder_->getNamedAttr(
          "dimension_numbers",
          ConvertConvDimensionNumbers(
              instruction->convolution_dimension_numbers(), builder_)));
      attributes.push_back(builder_->getNamedAttr(
          "feature_group_count",
          builder_->getI64IntegerAttr(instruction->feature_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "batch_group_count",
          builder_->getI64IntegerAttr(instruction->batch_group_count())));
      attributes.push_back(builder_->getNamedAttr(
          "precision_config",
          ConvertPrecisionConfig(&instruction->precision_config(), builder_)));

      return func_builder
          ->create<mlir::mhlo::ConvolutionOp>(loc, result_type, operands,
                                              attributes)
          .getOperation();
    }

    case HloOpcode::kFft: {
      auto fft_type = mlir::mhlo::FftTypeAttr::get(
          builder_->getContext(),
          mlir::mhlo::symbolizeFftType(FftType_Name(instruction->fft_type()))
              .getValue());

      std::vector<int64_t> fft_length(instruction->fft_length().begin(),
                                      instruction->fft_length().end());

      attributes.push_back(builder_->getNamedAttr("fft_type", fft_type));
      attributes.push_back(
          builder_->getNamedAttr("fft_length", Convert(fft_length)));
      return func_builder
          ->create<mlir::mhlo::FftOp>(loc, result_type, operands, attributes)
          .getOperation();
    }

    case HloOpcode::kAdd: {
      // HLO add ops on PRED elements are actually boolean or, but MHLO dialect
      // AddOps on i1 are just addition with overflow; so, we have to implement
      // the special behavior of HLO add ops on PRED here by creating an
      // mhlo::OrOp instead.
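      // For example, HLO gives pred(1) + pred(1) == pred(1), whereas i1
      // addition would wrap around to 0.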
      if (instruction->shape().element_type() == PRED) {
        return func_builder
            ->create<mlir::mhlo::OrOp>(loc, result_type, operands, attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AddOp>(loc, result_type, operands, attributes)
            .getOperation();
      }
    }
    case HloOpcode::kAfterAll: {
      // HLO AfterAll ops without any token input are used to just create a
      // token. MHLO has a special op, CreateToken, for this case.
      if (instruction->operands().empty()) {
        return func_builder
            ->create<mlir::mhlo::CreateTokenOp>(loc, result_type, operands,
                                                attributes)
            .getOperation();
      } else {
        return func_builder
            ->create<mlir::mhlo::AfterAllOp>(loc, result_type, operands,
                                             attributes)
            .getOperation();
      }
    }

    case HloOpcode::kConvert: {
      // Convert to boolean is special: it requires a comparison to 0 instead
      // of a truncation to i1; otherwise it is a 1-1 translation.
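      // For example, converting an i32 value of 2 to i1 must yield
      // (2 != 0) == true, whereas truncating to the low bit would yield 0.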
1346 auto ranked_type = result_type.dyn_cast<mlir::RankedTensorType>();
1347 mlir::IntegerType integer_type =
1348 (ranked_type)
1349 ? ranked_type.getElementType().dyn_cast<mlir::IntegerType>()
1350 : nullptr;
1351 if (!integer_type || integer_type.getWidth() != 1) {
1352 // Simple case: 1-1 mapping.
1353 return {func_builder->create<mlir::mhlo::ConvertOp>(
1354 loc, result_type, operands, attributes)};
1355 }
1356
1357 // Return type is boolean, let's use `operand != 0` instead of Convert.
1358 xla::Shape input_shape = instruction->operand(0)->shape();
1359 TF_ASSIGN_OR_RETURN(mlir::Type type,
1360 ConvertTensorShapeToType<mlir::RankedTensorType>(
1361 input_shape, *func_builder));
1362 auto zero = func_builder->create<mlir::mhlo::ConstantOp>(
1363 loc, func_builder->getZeroAttr(type));
1364 return {func_builder->create<mlir::mhlo::CompareOp>(
1365 loc, operands[0], zero, mlir::mhlo::ComparisonDirection::NE)};
1366 }
1367 case HloOpcode::kOptimizationBarrier: {
1368 llvm::SmallVector<Value> flattened_operands;
1369 llvm::SmallVector<Type> flattened_operand_types;
1370 FlattenTupleType(operands[0].getType(), flattened_operand_types);
1371 FlattenTupleValue(func_builder, loc, operands[0], flattened_operands);
1372
1373 auto op = func_builder->create<mlir::mhlo::OptimizationBarrierOp>(
1374 loc, flattened_operand_types, flattened_operands);
1375
1376 return CreateTupleFromOpResults(func_builder, loc, op.getOperation(),
1377 operands[0].getType());
1378 }
    case HloOpcode::kDomain: {
      auto domain_kind = mlir::mhlo::symbolizeDomainKind(
          instruction->user_side_metadata().Kind());
      if (!domain_kind || *domain_kind != mlir::mhlo::DomainKind::sharding) {
        return tensorflow::errors::InvalidArgument(
            "Invalid domain kind in hlo -> mhlo import. Only 'sharding' is "
            "supported");
      }
      attributes.push_back(builder_->getNamedAttr(
          "kind", mlir::mhlo::DomainKindAttr::get(func_builder->getContext(),
                                                  *domain_kind)));

      // In XLA, DomainMetadata is open-world, but in the proto, it is
      // hardcoded to be ShardingMetadata. Thankfully, the only other
      // implementation of DomainMetadata is OpName, which is generally used
      // for debugging and never for compiling production models.
      //
      // Since this is hardcoded as such in the proto, we must follow suit.
      // TODO(b/208783683): The one improvement we can make on this is to move
      // from a serialized proto representation to a parsable string.
      auto exit_metadata = ShardingMetadata::ToShardingMetadata(
          &instruction->operand_side_metadata());
      auto entry_metadata = ShardingMetadata::ToShardingMetadata(
          &instruction->user_side_metadata());
      attributes.push_back(builder_->getNamedAttr(
          "exit_metadata",
          builder_->getStringAttr(
              (*exit_metadata)->sharding()->ToProto().SerializeAsString())));
      attributes.push_back(builder_->getNamedAttr(
          "entry_metadata",
          builder_->getStringAttr(
              (*entry_metadata)->sharding()->ToProto().SerializeAsString())));

      return func_builder
          ->create<mlir::mhlo::DomainOp>(loc, result_type, operands,
                                         attributes)
          .getOperation();
    }

#define NO_ATTRIBUTE_CASE(hlo_op_code, mlir_op)                               \
  case HloOpcode::hlo_op_code: {                                              \
    return func_builder                                                       \
        ->create<mlir::mhlo::mlir_op>(loc, result_type, operands, attributes) \
        .getOperation();                                                      \
  }
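// For example, NO_ATTRIBUTE_CASE(kAbs, AbsOp) expands to a case that simply
// builds an mlir::mhlo::AbsOp from the shared loc, result_type, operands, and
// attributes computed above.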

    // Broadcast dimensions are never added here because they don't exist as
    // part of the HLO instruction. They are only a convenience in the XLA
    // builder API.
    NO_ATTRIBUTE_CASE(kAbs, AbsOp);
    NO_ATTRIBUTE_CASE(kAddDependency, AddDependencyOp);
    NO_ATTRIBUTE_CASE(kAnd, AndOp);
    NO_ATTRIBUTE_CASE(kAtan2, Atan2Op);
    NO_ATTRIBUTE_CASE(kBitcastConvert, BitcastConvertOp);
    NO_ATTRIBUTE_CASE(kCbrt, CbrtOp);
    NO_ATTRIBUTE_CASE(kClz, ClzOp);
    NO_ATTRIBUTE_CASE(kCeil, CeilOp);
    NO_ATTRIBUTE_CASE(kClamp, ClampOp);
    NO_ATTRIBUTE_CASE(kComplex, ComplexOp);
    NO_ATTRIBUTE_CASE(kCos, CosineOp);
    NO_ATTRIBUTE_CASE(kDivide, DivOp);
    NO_ATTRIBUTE_CASE(kExp, ExpOp);
    NO_ATTRIBUTE_CASE(kExpm1, Expm1Op);
    NO_ATTRIBUTE_CASE(kFloor, FloorOp);
    NO_ATTRIBUTE_CASE(kIsFinite, IsFiniteOp);
    NO_ATTRIBUTE_CASE(kImag, ImagOp);
    NO_ATTRIBUTE_CASE(kLog, LogOp);
    NO_ATTRIBUTE_CASE(kLog1p, Log1pOp);
    NO_ATTRIBUTE_CASE(kMaximum, MaxOp);
    NO_ATTRIBUTE_CASE(kMinimum, MinOp);
    NO_ATTRIBUTE_CASE(kMultiply, MulOp);
    NO_ATTRIBUTE_CASE(kNegate, NegOp);
    NO_ATTRIBUTE_CASE(kNot, NotOp);
    NO_ATTRIBUTE_CASE(kOr, OrOp);
    NO_ATTRIBUTE_CASE(kPartitionId, PartitionIdOp);
    NO_ATTRIBUTE_CASE(kPopulationCount, PopulationCountOp);
    NO_ATTRIBUTE_CASE(kPower, PowOp);
    NO_ATTRIBUTE_CASE(kReal, RealOp);
    NO_ATTRIBUTE_CASE(kRemainder, RemOp);
    NO_ATTRIBUTE_CASE(kReplicaId, ReplicaIdOp);
    NO_ATTRIBUTE_CASE(kLogistic, LogisticOp);
    // The dimensions attribute is not present on the HLO Reshape
    // instruction. If dimensions are non-default, the XLA builder
    // implements it as a separate transpose.
    NO_ATTRIBUTE_CASE(kReshape, ReshapeOp);
    NO_ATTRIBUTE_CASE(kRoundNearestAfz, RoundOp);
    NO_ATTRIBUTE_CASE(kRoundNearestEven, RoundNearestEvenOp);
    NO_ATTRIBUTE_CASE(kRsqrt, RsqrtOp);
    NO_ATTRIBUTE_CASE(kSelect, SelectOp);
    NO_ATTRIBUTE_CASE(kShiftLeft, ShiftLeftOp);
    NO_ATTRIBUTE_CASE(kShiftRightArithmetic, ShiftRightArithmeticOp);
    NO_ATTRIBUTE_CASE(kShiftRightLogical, ShiftRightLogicalOp);
    NO_ATTRIBUTE_CASE(kSign, SignOp);
    NO_ATTRIBUTE_CASE(kSin, SineOp);
    NO_ATTRIBUTE_CASE(kSqrt, SqrtOp);
    NO_ATTRIBUTE_CASE(kSubtract, SubtractOp);
    NO_ATTRIBUTE_CASE(kTanh, TanhOp);
    NO_ATTRIBUTE_CASE(kTuple, TupleOp);
    NO_ATTRIBUTE_CASE(kXor, XorOp);
    // TODO(b/129422361) Copy needs special handling because it is not
    // defined in tensorflow/compiler/xla/client/xla_builder.h. See
    // operation semantics in
    // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
    NO_ATTRIBUTE_CASE(kCopy, CopyOp);

#undef NO_ATTRIBUTE_CASE

    case HloOpcode::kFusion: {
      // Flatten the tuple-typed operands.
      llvm::SmallVector<Value> flattened_operands;
      for (auto& operand : operands)
        FlattenTupleValue(func_builder, loc, operand, flattened_operands);

      // Flatten the return type if it is tuple-typed.
      llvm::SmallVector<Type> flattened_ret_types;
      FlattenTupleType(result_type, flattened_ret_types);

      auto fusion_kind = mlir::mhlo::symbolizeFusionKind(
          xla::ToString(instruction->fusion_kind()));
      auto fusion = func_builder->create<mlir::mhlo::FusionOp>(
          loc, flattened_ret_types, flattened_operands,
          mlir::mhlo::FusionKindAttr::get(func_builder->getContext(),
                                          fusion_kind.getValue()));
      TF_RETURN_IF_ERROR(ImportAsRegion(
          *instruction->fused_instructions_computation(),
          &fusion.fused_computation(), /*flatten_region_arg_tuple=*/true));

      return CreateTupleFromOpResults(func_builder, loc, fusion.getOperation(),
                                      result_type);
    }
    case HloOpcode::kBitcast: {
      auto bitcast = func_builder->create<mlir::mhlo::BitcastOp>(
          loc, result_type, operands, attributes);
      // Store the source and result layouts as attributes. Although the MHLO
      // BitcastOp operates on tensors, these layouts are relevant as they
      // define the mapping between the elements of the source and the result.
      SetLayoutForMlir(bitcast, instruction->shape(), "result_layout");
      SetLayoutForMlir(bitcast, instruction->operand(0)->shape(),
                       "source_layout");
      return bitcast.getOperation();
    }
    case HloOpcode::kReducePrecision: {
      auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
          loc, result_type, operands[0], attributes);
      op.exponent_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->exponent_bits()));
      op.mantissa_bitsAttr(func_builder->getIntegerAttr(
          func_builder->getI32Type(), instruction->mantissa_bits()));
      return op.getOperation();
    }
    default: {
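      // Fall back to an opaque "mhlo.unknown" operation that preserves the
      // operands, result type, and attributes, so instructions without a
      // dedicated MHLO lowering still survive the import.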
      mlir::OperationState result(loc, "mhlo.unknown");
      result.addOperands(operands);
      result.addTypes(result_type);
      for (auto attr : attributes) {
        result.attributes.push_back(attr);
      }

      return func_builder->create(result);
    }
  }
}

void SetXlaShape(mlir::Operation* op, const Shape& shape) {
  op->setAttr("xla_shape",
              mlir::Builder(op->getContext())
                  .getStringAttr(shape.ToString(/*print_layout=*/true)));
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionWithLayout(
    const HloInstruction* instruction,
    const llvm::SmallVectorImpl<mlir::Value>& operands,
    mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) {
  TF_ASSIGN_OR_RETURN(
      mlir::Operation * op,
      ImportInstructionImpl(instruction, operands, func_builder, mode));
  if (op == nullptr) return op;

  // See MlirToHloConversionOptions for more about layouts.
  //
  // Minor-to-major is a permutation of [0, rank), presenting tensor
  // dimensions in physical minor-to-major order.
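  // Example (illustrative): for a rank-2 shape, minor_to_major = {1, 0} is
  // the default descending layout and needs no attribute, whereas {0, 1}
  // (column-major) is recorded on the op via SetXlaShape below.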
  if (instruction->shape().IsArray()) {
    if (instruction->shape().has_layout() &&
        !instruction->shape().layout().minor_to_major().empty() &&
        instruction->shape().layout() !=
            LayoutUtil::MakeDescendingLayout(
                instruction->shape().dimensions().size())) {
      SetXlaShape(op, instruction->shape());
    }
  } else {
    SetXlaShape(op, instruction->shape());
  }
  return op;
}

StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
    const HloInstruction* instruction) {
  llvm::SmallVector<mlir::Value, 4> operands;
  for (const auto& operand : instruction->operands()) {
    auto input_it = instruction_value_map_.find(operand);
    if (input_it == instruction_value_map_.end()) {
      return tensorflow::errors::Internal(
          absl::StrCat("Could not find input value: ", operand->name(),
                       " for instruction ", instruction->name()));
    }
    operands.push_back(input_it->second);
  }
  return operands;
}

tensorflow::Status HloFunctionImporter::GetMlirTypes(
    const std::vector<HloInstruction*>& instructions,
    llvm::SmallVectorImpl<mlir::Type>* types) {
  for (auto instruction : instructions) {
    TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
                                           instruction->shape(), *builder_));
    types->push_back(ret_type);
  }
  return ::tensorflow::OkStatus();
}

StatusOr<Value> HloFunctionImporter::GetMlirValue(
    const HloInstruction* instruction) {
  auto lookup = instruction_value_map_.find(instruction);
  if (lookup != instruction_value_map_.end()) {
    return lookup->second;
  }

  return tensorflow::errors::Internal(absl::StrCat(
      "Unable to find value for input: ", instruction->ToString()));
}

mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
    ComparisonDirection direction) {
  return builder_->getNamedAttr(
      "comparison_direction",
      mlir::mhlo::ComparisonDirectionAttr::get(
          builder_->getContext(), mlir::mhlo::symbolizeComparisonDirection(
                                      ComparisonDirectionToString(direction))
                                      .getValue()));
}

mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
    Comparison::Type type) {
  return builder_->getNamedAttr(
      "compare_type",
      mlir::mhlo::ComparisonTypeAttr::get(
          builder_->getContext(),
          mlir::mhlo::symbolizeComparisonType(ComparisonTypeToString(type))
              .getValue()));
}

mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
    absl::Span<const int64_t> op_dimensions) {
  llvm::SmallVector<APInt, 8> dimensions;
  dimensions.reserve(op_dimensions.size());
  for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));

  return DenseIntElementsAttr::get(
      RankedTensorType::get(dimensions.size(), builder_->getIntegerType(64)),
      dimensions);
}

mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
    llvm::ArrayRef<int64_t> elements) {
  return DenseIntElementsAttr::get(
      RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
      elements);
}

mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
    llvm::ArrayRef<bool> elements) {
  return DenseIntElementsAttr::get(
      RankedTensorType::get(elements.size(), builder_->getI1Type()), elements);
}

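// Reshapes a flat list of padding amounts into an Nx2 "padding" attribute,
// with each consecutive pair of elements forming one row. Illustrative
// example: {1, 2, 0, 0} becomes dense<[[1, 2], [0, 0]]> : tensor<2x2xi64>.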
mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
    llvm::ArrayRef<int64_t> padding) {
  auto ty = mlir::RankedTensorType::get(
      {static_cast<int64_t>(padding.size()) / 2, 2},
      builder_->getIntegerType(64));
  auto attr = DenseIntElementsAttr::get(ty, padding);
  return builder_->getNamedAttr("padding", attr);
}

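// Packs (source, target) pairs into an Nx2 "source_target_pairs" attribute.
// Illustrative example: {{0, 1}, {1, 0}} becomes
// dense<[[0, 1], [1, 0]]> : tensor<2x2xi64>.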
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
    mlir::Builder* builder) {
  std::vector<int64_t> attr(source_target_pairs.size() * 2);
  for (const auto& p : llvm::enumerate(source_target_pairs)) {
    attr[2 * p.index()] = p.value().first;
    attr[2 * p.index() + 1] = p.value().second;
  }
  auto type = mlir::RankedTensorType::get(
      {static_cast<int64_t>(attr.size() / 2), 2}, builder->getIntegerType(64));
  return builder->getNamedAttr("source_target_pairs",
                               DenseIntElementsAttr::get(type, attr));
}

mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
    absl::Span<const ReplicaGroup> replica_groups, mlir::Builder* builder) {
  const int64_t num_groups = replica_groups.size();
  // Replica groups in HLO can be non-uniform in size, for example:
  // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D
  // tensor, pad the smaller replica groups with -1.
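  // Illustrative example: {{0},{1,2},{3}} yields group_size 2 and is stored
  // as dense<[[0, -1], [1, 2], [3, -1]]> : tensor<3x2xi64>.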
  const int64_t group_size = absl::c_accumulate(
      replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) {
        return std::max<int64_t>(current, g.replica_ids_size());
      });
  // Initialize all elements to -1 to support non-uniform replica groups.
  std::vector<int64_t> attr(num_groups * group_size, -1);
  for (int i = 0; i < num_groups; ++i) {
    int index = i * group_size;
    for (const int64_t& id : replica_groups[i].replica_ids())
      attr[index++] = id;
  }
  auto type = mlir::RankedTensorType::get({num_groups, group_size},
                                          builder->getIntegerType(64));
  return builder->getNamedAttr("replica_groups",
                               DenseIntElementsAttr::get(type, attr));
}

mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
    std::optional<int64_t> channel_id) {
  xla::ChannelHandle channel_handle;
  if (channel_id) channel_handle.set_handle(*channel_id);
  return ConvertChannelHandle(channel_handle);
}

mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
    const xla::ChannelHandle& channel) {
  return builder_->getNamedAttr(
      "channel_handle", mlir::mhlo::ChannelHandleAttr::get(
                            context_, channel.handle(), channel.type()));
}

mlir::NamedAttribute HloFunctionImporter::ConvertUseGlobalDeviceIds() {
  return builder_->getNamedAttr("use_global_device_ids",
                                builder_->getUnitAttr());
}

void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op,
                                           const Shape& shape,
                                           llvm::StringRef attr_name) {
  llvm::SmallVector<int64_t, 4> minor_to_major(
      shape.layout().minor_to_major().begin(),
      shape.layout().minor_to_major().end());
  op->setAttr(
      attr_name,
      mlir::Builder(op->getContext()).getIndexTensorAttr(minor_to_major));
}

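// Recursively flattens the layouts of a (possibly nested) tuple shape into a
// flat list of minor-to-major array attributes; token shapes contribute
// nothing. Illustrative example: a tuple of f32[2,3]{1,0} and f32[4]{0}
// flattens to [[1, 0], [0]].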
Status HloFunctionImporter::ConvertShapeToMlirLayout(
    const xla::Shape& shape,
    llvm::SmallVectorImpl<mlir::Attribute>& flattened_attr) {
  if (shape.IsToken()) {
    return ::tensorflow::OkStatus();
  }
  if (shape.IsTuple()) {
    for (int i = 0; i < shape.tuple_shapes_size(); i++) {
      TF_RETURN_IF_ERROR(
          ConvertShapeToMlirLayout(shape.tuple_shapes(i), flattened_attr));
    }
    return ::tensorflow::OkStatus();
  }
  if (shape.IsArray()) {
    const xla::Layout l = shape.layout();
    std::vector<mlir::Attribute> minor_to_major;
    for (int64_t i : l.minor_to_major()) {
      minor_to_major.push_back(builder_->getI64IntegerAttr(i));
    }
    llvm::ArrayRef<mlir::Attribute> array_ref(minor_to_major);
    flattened_attr.push_back(builder_->getArrayAttr(array_ref));
    return ::tensorflow::OkStatus();
  }
  return tensorflow::errors::Internal("Couldn't convert layout.");
}

}  // namespace xla