/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"

#include <memory>
#include <optional>
#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using ::int64_t;
using ::stream_executor::port::StatusOr;
using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int8;
using ::tensorflow::uint16;
using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;

constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";

// Array attribute. Same shape as infeed result, but contains a
// minor_to_major array for every tensor.
constexpr char kLayoutAttr[] = "layout";
constexpr char kDefaultLayoutAttrName[] = "xla_shape";

// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
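// For example, the exporters below pass
// Unwrap(Convert_precision_config(...)) to xla::Dot and xla::ConvGeneralDilated
// so that the PrecisionConfig stays owned by the returned unique_ptr.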
template <typename T>
T Unwrap(T t) {
  return t;
}

template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
  return t.get();
}

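// Looks up the XlaOp that `val` was lowered to in `val_map` and stores it in
// `result`. Emits an error on `op` and returns failure if `val` has not been
// lowered, e.g. because it is defined outside the region being exported.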
static mlir::LogicalResult GetXlaOp(
    mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map,
    xla::XlaOp* result, mlir::Operation* op) {
  auto iter = val_map.find(val);
  if (iter == val_map.end()) {
    return op->emitOpError(
        "requires all operands to be defined in the parent region for export");
  }
  *result = iter->second;
  return mlir::success();
}

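// Returns true if `ty` is a ranked tensor type in which every dimension is
// either static or carries a bound in the mhlo type-extensions encoding.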
bool IsBoundedOrStatic(mlir::Type ty) {
  auto ranked_ty = ty.dyn_cast_or_null<mlir::RankedTensorType>();
  if (!ranked_ty) return false;

  if (ranked_ty.hasStaticShape()) return true;

  auto encoding = ranked_ty.getEncoding()
                      .dyn_cast_or_null<mlir::mhlo::TypeExtensionsAttr>();
  if (!encoding || encoding.getBounds().empty()) return false;

  int64_t rank = ranked_ty.getRank();
  for (int64_t dim = 0; dim < rank; ++dim) {
    if (ranked_ty.isDynamicDim(dim) &&
        encoding.getBounds()[dim] == mlir::ShapedType::kDynamicSize)
      return false;
  }
  return true;
}

// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }

static uint32_t Convertuint32_t(uint32_t i) { return i; }
static uint64_t Convertuint64_t(uint64_t i) { return i; }

// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
  const auto& semantics = value.getSemantics();
  bool losesInfo = false;
  if (&semantics != &llvm::APFloat::IEEEdouble())
    value.convert(llvm::APFloat::IEEEdouble(),
                  llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return value.convertToDouble();
}

static inline bool Convertbool(bool value) { return value; }

static absl::string_view ConvertStringRef(mlir::StringRef value) {
  return {value.data(), value.size()};
}

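// Converts a (possibly absent) DenseIntElementsAttr into a vector of int64_t.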
static std::vector<int64_t> ConvertDenseIntAttr(
    mlir::DenseIntElementsAttr attr) {
  auto values = attr.getValues<int64_t>();
  return {values.begin(), values.end()};
}

static std::vector<int64_t> ConvertDenseIntAttr(
    llvm::Optional<mlir::DenseIntElementsAttr> attr) {
  if (!attr) return {};
  return ConvertDenseIntAttr(*attr);
}

// Converts the broadcast_dimensions attribute into a vector of dimension
// numbers (empty if the attribute is absent).
static std::vector<int64_t> Convert_broadcast_dimensions(
    llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) {
  if (!broadcast_dimensions.has_value()) return {};

  return ConvertDenseIntAttr(*broadcast_dimensions);
}

// Converts the MHLO FftType enum into the XLA FftType proto enum.
static xla::FftType Convert_fft_type(mlir::mhlo::FftType fft_type) {
  xla::FftType fft_type_enum;
  // An illegal fft_type would have been caught by the verifier, so the
  // FftType_Parse call below should never fail.
  if (!FftType_Parse(std::string(mlir::mhlo::stringifyFftType(fft_type)),
                     &fft_type_enum))
    return xla::FftType::FFT;
  return fft_type_enum;
}

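// Converts an optional Nx2 padding attribute into the vector of (low, high)
// pairs expected by the XLA client API.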
static std::vector<std::pair<int64_t, int64_t>> Convert_padding(
    llvm::Optional<mlir::DenseIntElementsAttr> padding) {
  return xla::ConvertNx2Attribute(padding).ValueOrDie();
}

static std::optional<bool> Convert_use_global_device_ids(
    llvm::Optional<bool> use_global_device_ids) {
  if (!use_global_device_ids) return {};
  return *use_global_device_ids;
}

static std::vector<std::pair<int64_t, int64_t>> Convert_source_target_pairs(
    llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
  return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
}

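// Converts a 2D replica_groups attribute into a vector of xla::ReplicaGroup
// protos, one per row.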
static std::vector<xla::ReplicaGroup> Convert_replica_groups(
    mlir::DenseIntElementsAttr groups) {
  return xla::ConvertReplicaGroups(groups).ValueOrDie();
}

// Converts types and corresponding layouts into xla shapes with layouts.
static std::vector<xla::Shape> ConvertTypesToShapesWithLayout(
    mlir::TypeRange value_types, mlir::ArrayAttr layouts) {
  std::vector<xla::Shape> shapes_with_layout;
  for (auto type_and_layout : llvm::zip(value_types, layouts)) {
    mlir::Type type = std::get<0>(type_and_layout);
    mlir::Attribute layout = std::get<1>(type_and_layout);
    assert(!type.isa<mlir::TupleType>() &&
           "Exporting layout for tuples is not implemented yet");
    shapes_with_layout.emplace_back(xla::TypeToShape(type));
    auto& shape = shapes_with_layout.back();
    shape.mutable_layout()->clear_minor_to_major();
    for (auto l : layout.cast<mlir::DenseIntElementsAttr>()) {
      shape.mutable_layout()->mutable_minor_to_major()->push_back(
          l.getSExtValue());
    }
  }
  return shapes_with_layout;
}

// CustomCallOp result can be of tuple type to pack multiple results into one
// value. If the custom call result is a tuple, then result layouts represent
// the layout of each element of the tuple. Nested tuples are currently not
// supported for export.
static xla::Shape GetCustomCallResultShapeWithLayout(mlir::Type type,
                                                     mlir::ArrayAttr layouts) {
  auto tuple_type = type.dyn_cast<mlir::TupleType>();
  if (!tuple_type) return ConvertTypesToShapesWithLayout({type}, layouts)[0];

  std::vector<xla::Shape> shapes_with_layouts =
      ConvertTypesToShapesWithLayout(tuple_type.getTypes(), layouts);
  return xla::ShapeUtil::MakeTupleShape(shapes_with_layouts);
}

// Converts the MHLO Transpose enum into the XLA TriangularSolveOptions
// Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
    mlir::mhlo::Transpose transpose) {
  return xla::ConvertTranspose(mlir::mhlo::stringifyTranspose(transpose))
      .ValueOrDie();
}

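// Returns the minor_to_major layout stored in the `attr_name` attribute of
// `op`, or the default descending layout for the given rank if the attribute
// is absent.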
static xla::Layout ExtractLayout(
    mlir::Operation* op, int rank,
    llvm::StringRef attr_name = kDefaultLayoutAttrName) {
  if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(attr_name)) {
    llvm::SmallVector<int64_t, 4> minor_to_major;
    DCHECK_EQ(rank, attr.size());
    minor_to_major.reserve(attr.size());
    for (const llvm::APInt& i : attr) {
      minor_to_major.push_back(i.getZExtValue());
    }
    return xla::LayoutUtil::MakeLayout(minor_to_major);
  }
  return xla::LayoutUtil::MakeDescendingLayout(rank);
}

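// Returns the result shape of `op`: the shape parsed from the "xla_shape"
// string attribute if present, otherwise a shape derived from the result
// types (a tuple shape if the op has more than one result).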
static xla::Shape ExtractXlaShape(mlir::Operation* op) {
  if (auto attr = op->getAttrOfType<mlir::StringAttr>(kDefaultLayoutAttrName)) {
    return *xla::ParseShape(
        absl::string_view(attr.getValue().data(), attr.getValue().size()));
  } else {
    std::vector<xla::Shape> subshapes;
    for (mlir::Value result : op->getResults()) {
      subshapes.push_back(xla::TypeToShape(result.getType()));
    }
    if (subshapes.size() > 1) {
      return xla::ShapeUtil::MakeTupleShape(subshapes);
    }
    return subshapes[0];
  }
}

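// Generates a Convert_<attribute> helper that turns an optional
// DenseIntElementsAttr into a vector of int64_t (empty if absent).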
#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute)                \
  static std::vector<int64_t> Convert_##attribute(            \
      llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
    return ConvertDenseIntAttr(attribute);                    \
  }

I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(strides);
I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);

#undef I64_ELEMENTS_ATTR_TO_VECTOR

#define BOOL_ELEMENTS_ATTR_TO_VECTOR(attribute)            \
  static std::vector<bool> Convert_##attribute(            \
      llvm::Optional<mlir::DenseElementsAttr> attribute) { \
    if (!attribute) return {};                             \
    auto values = attribute->getValues<bool>();            \
    return {values.begin(), values.end()};                 \
  }

BOOL_ELEMENTS_ATTR_TO_VECTOR(window_reversal);

#undef BOOL_ELEMENTS_ATTR_TO_VECTOR

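// Copies an ArrayRef of int64_t into a std::vector.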
static std::vector<int64_t> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
  return {values.begin(), values.end()};
}

// Converts the precision config array of strings attribute into the
// corresponding XLA proto. All the strings are assumed to be valid names of the
// Precision enum. This should have been checked in the op verify method.
static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
    llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
  if (!optional_precision_config_attr.has_value()) return nullptr;

  auto precision_config = std::make_unique<xla::PrecisionConfig>();
  for (auto attr : optional_precision_config_attr.getValue()) {
    xla::PrecisionConfig::Precision p;
    auto operand_precision =
        mlir::mhlo::stringifyPrecision(
            attr.cast<mlir::mhlo::PrecisionAttr>().getValue())
            .str();
    // TODO(jpienaar): Update this to ensure this is captured by verify.
    if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
      precision_config->add_operand_precision(p);
    } else {
      auto* context = attr.getContext();
      mlir::emitError(mlir::UnknownLoc::get(context))
          << "unexpected operand precision " << operand_precision;
      return nullptr;
    }
  }

  return precision_config;
}

static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
    mlir::mhlo::DotDimensionNumbersAttr dot_dimension_numbers_attr) {
  xla::DotDimensionNumbers dot_dimension_numbers;

  auto rhs_contracting_dimensions =
      dot_dimension_numbers_attr.getRhsContractingDimensions();
  auto lhs_contracting_dimensions =
      dot_dimension_numbers_attr.getLhsContractingDimensions();
  auto rhs_batch_dimensions =
      dot_dimension_numbers_attr.getRhsBatchingDimensions();
  auto lhs_batch_dimensions =
      dot_dimension_numbers_attr.getLhsBatchingDimensions();

  for (const auto& val : rhs_contracting_dimensions) {
    dot_dimension_numbers.add_rhs_contracting_dimensions(val);
  }
  for (const auto& val : lhs_contracting_dimensions) {
    dot_dimension_numbers.add_lhs_contracting_dimensions(val);
  }

  for (const auto& val : rhs_batch_dimensions) {
    dot_dimension_numbers.add_rhs_batch_dimensions(val);
  }

  for (const auto& val : lhs_batch_dimensions) {
    dot_dimension_numbers.add_lhs_batch_dimensions(val);
  }

  return dot_dimension_numbers;
}

static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
    mlir::mhlo::ConvDimensionNumbersAttr input) {
  return xla::ConvertConvDimensionNumbers(input);
}

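// Converts an MHLO ChannelHandleAttr into the corresponding XLA ChannelHandle
// proto.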
xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandleAttr attr) {
  xla::ChannelHandle channel_handle;
  channel_handle.set_handle(attr.getHandle());
  channel_handle.set_type(
      static_cast<xla::ChannelHandle::ChannelType>(attr.getType()));
  return channel_handle;
}

std::optional<xla::ChannelHandle> Convert_channel_handle(
    llvm::Optional<mlir::mhlo::ChannelHandleAttr> attr) {
  if (!attr.has_value()) return std::nullopt;
  return Convert_channel_handle(attr.getValue());
}

// Converts the comparison_direction string attribute into the XLA enum. The
// string is assumed to correspond to exactly one of the allowed strings
// representing the enum. This should have been checked in the op verify method.
static xla::ComparisonDirection Convert_comparison_direction(
    llvm::StringRef comparison_direction_string) {
  return xla::StringToComparisonDirection(comparison_direction_string.str())
      .ValueOrDie();
}

static xla::GatherDimensionNumbers Convert_dimension_numbers(
    mlir::mhlo::GatherDimensionNumbersAttr input) {
  xla::GatherDimensionNumbers output;

  auto offset_dims = input.getOffsetDims();
  std::copy(offset_dims.begin(), offset_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_offset_dims()));

  auto collapsed_slice_dims = input.getCollapsedSliceDims();
  std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_collapsed_slice_dims()));

  auto start_index_map = input.getStartIndexMap();
  std::copy(start_index_map.begin(), start_index_map.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_start_index_map()));

  output.set_index_vector_dim(input.getIndexVectorDim());
  return output;
}

static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
    mlir::mhlo::ScatterDimensionNumbersAttr input) {
  xla::ScatterDimensionNumbers output;

  auto update_window_dims = input.getUpdateWindowDims();
  std::copy(update_window_dims.begin(), update_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_update_window_dims()));

  auto inserted_window_dims = input.getInsertedWindowDims();
  std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_inserted_window_dims()));

  auto scatter_dims_to_operand_dims = input.getScatterDimsToOperandDims();
  std::copy(scatter_dims_to_operand_dims.begin(),
            scatter_dims_to_operand_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_scatter_dims_to_operand_dims()));

  output.set_index_vector_dim(input.getIndexVectorDim());
  return output;
}

// Extracts sharding from attribute string.
static std::optional<xla::OpSharding> CreateOpShardingFromStringRef(
    llvm::StringRef sharding) {
  xla::OpSharding sharding_proto;
  if (!sharding_proto.ParseFromString(sharding.str())) return std::nullopt;
  return sharding_proto;
}

// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns std::nullopt.
static std::optional<xla::OpSharding> CreateOpShardingFromAttribute(
    mlir::Operation* op) {
  auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
  if (!sharding) return std::nullopt;
  return CreateOpShardingFromStringRef(sharding.getValue());
}

// Returns a FrontendAttributes proto from the "frontend_attributes" attribute
// of the op. An empty FrontendAttributes proto is returned if an op does not
// have frontend attributes.
static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
    mlir::Operation* op) {
  xla::FrontendAttributes frontend_attributes;
  auto frontend_attributes_dict =
      op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);

  if (!frontend_attributes_dict) return frontend_attributes;

  for (const auto& attr : frontend_attributes_dict)
    if (auto value_str_attr = attr.getValue().dyn_cast<mlir::StringAttr>())
      frontend_attributes.mutable_map()->insert(
          {attr.getName().str(), value_str_attr.getValue().str()});

  return frontend_attributes;
}

// Returns an OpMetadata proto based on the location of the op. If the location
// is unknown, an empty proto is returned. `op_name` is populated with the
// converted op location. FileLineColLoc locations additionally populate
// `source_file` and `source_line` with the file name and line number.
static xla::OpMetadata CreateOpMetadataFromLocation(
    mlir::Operation* op, mlir::MlirToHloConversionOptions options) {
  xla::OpMetadata metadata;
  mlir::Location loc = op->getLoc();
  if (loc.isa<mlir::UnknownLoc>()) return metadata;

  std::string name = mlir::GetNameFromLoc(loc);
  if (options.legalize_node_names) {
    mlir::LegalizeNodeName(name);
  }
  metadata.set_op_name(name);
  std::string op_type = mlir::GetOpTypeFromLoc(loc);
  mlir::LegalizeNodeName(op_type);
  metadata.set_op_type(op_type);

  if (auto name_loc = op->getLoc().dyn_cast<mlir::NameLoc>()) {
    loc = name_loc.getChildLoc();
    if (loc.isa<mlir::UnknownLoc>()) return metadata;
  }

  if (auto file_line_col_loc = loc.dyn_cast<mlir::FileLineColLoc>()) {
    metadata.set_source_file(file_line_col_loc.getFilename().str());
    metadata.set_source_line(file_line_col_loc.getLine());
  }

  return metadata;
}

// Checks if all shardings are set.
static bool AllOptionalShardingsAreSet(
    llvm::ArrayRef<std::optional<xla::OpSharding>> shardings) {
  return llvm::all_of(shardings,
                      [](const std::optional<xla::OpSharding>& sharding) {
                        return sharding.has_value();
                      });
}

// Extracts argument and result shardings from function.
static void ExtractShardingsFromFunction(
    mlir::func::FuncOp function,
    llvm::SmallVectorImpl<std::optional<xla::OpSharding>>* arg_shardings,
    llvm::SmallVectorImpl<std::optional<xla::OpSharding>>* ret_shardings) {
  arg_shardings->resize(function.getNumArguments(),
                        std::optional<xla::OpSharding>());
  for (int i = 0, end = function.getNumArguments(); i < end; ++i)
    if (auto sharding =
            function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());

  ret_shardings->resize(function.getNumResults(),
                        std::optional<xla::OpSharding>());
  for (int i = 0, end = function.getNumResults(); i < end; ++i)
    if (auto sharding =
            function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
}

namespace mlir {
namespace {
class ConvertToHloModule {
 public:
  using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
  using FunctionLoweringMap =
      llvm::DenseMap<mlir::func::FuncOp, xla::XlaComputation>;

  // If use_tuple_args is true, then the entry function's arguments are
  // converted to a tuple and passed as a single parameter.
  // Similarly, if return_tuple is true, then the entry function's return values
  // are converted to a tuple even when there is only a single return value.
  // Multiple return values are always converted to a tuple and returned as a
  // single value.
  explicit ConvertToHloModule(mlir::ModuleOp module,
                              xla::XlaBuilder& module_builder,
                              bool use_tuple_args, bool return_tuple,
                              MlirToHloConversionOptions options)
      : module_(module),
        module_builder_(module_builder),
        use_tuple_args_(use_tuple_args),
        return_tuple_(return_tuple),
        options_(options) {}

  // Perform the lowering to XLA. This function returns failure if an error was
  // encountered.
  //
  // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
  LogicalResult Run() {
    auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main)
      return module_.emitError(
          "conversion requires module with `main` function");

    for (auto func : module_.getOps<func::FuncOp>()) {
      if (func.empty()) continue;
      if (failed(RunOnFunction(func))) return failure();
    }
    return success();
  }

  // Lower a specific function to HLO.
  LogicalResult RunOnFunction(mlir::func::FuncOp f);

  // Lower a `mlir::Region` to a `XlaComputation`
  LogicalResult LowerRegionAsComputation(
      mlir::Region* region, xla::XlaComputation* func,
      llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
          llvm::None,
      bool ensure_single_arg = false);

  // Lower a single `Block` to a `XlaComputation`
  LogicalResult LowerBasicBlockAsFunction(
      Block* block, xla::XlaBuilder* builder, bool is_entry_function,
      bool ensure_single_arg,
      const std::vector<bool>& entry_args_same_across_replicas,
      llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
      llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
      xla::XlaComputation* result,
      llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
          llvm::None);

  ::xla::HloModuleProto ConsumeMainProto() {
    auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
    // This is an invariant check as Run returns failure if there is no main
    // function and so the main proto shouldn't be consumed in that case.
    CHECK(main) << "requires module to have main function";  // Crash Ok.
    return lowered_computation_[main].proto();
  }

  // Lower function call to HLO call instruction
  LogicalResult LowerFunctionCall(
      mlir::func::CallOp call_op, xla::XlaBuilder* builder,
      ConvertToHloModule::ValueLoweringMap* value_lowering);

  // Look up a symbol with the specified name, returning null if no such name
  // exists.
  func::FuncOp LookUpSymbol(FlatSymbolRefAttr symbol) {
    return module_.lookupSymbol<mlir::func::FuncOp>(symbol);
  }

  // Get Reference to lowered XLA computation for a function.
  xla::XlaComputation& GetLoweredComputation(func::FuncOp func) {
    return lowered_computation_[func];
  }

  LogicalResult Lower(
      mlir::Operation* inst, bool is_entry_function,
      llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
      xla::XlaBuilder* builder,
      ConvertToHloModule::ValueLoweringMap* value_lowering,
      xla::XlaOp* return_value);

  const MlirToHloConversionOptions& GetOptions() const { return options_; }

 private:
  LogicalResult SetEntryTupleShapesAndLeafReplication(
      Block* block, const std::vector<bool>& entry_args_same_across_replicas,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
      std::vector<bool>* leaf_replication);

  LogicalResult SetEntryTupleShardings(
      Block* block, xla::XlaBuilder* builder,
      llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes);

  // The module being lowered.
  mlir::ModuleOp module_;

  // The top-level XlaBuilder.
  xla::XlaBuilder& module_builder_;

  // Map between function and lowered computation.
  FunctionLoweringMap lowered_computation_;

  // Whether the entry function should take a single tuple as input.
  bool use_tuple_args_;

  // Whether to always return a tuple.
  bool return_tuple_;

  // Unique suffix to give to the name of the next lowered region.
  size_t region_id_ = 0;

  MlirToHloConversionOptions options_;
};

}  // namespace
}  // namespace mlir

namespace {

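// Shared state passed to the per-op export functions: the value-to-XlaOp map
// for the region being lowered, the enclosing module converter, and the
// XlaBuilder for the computation under construction.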
struct OpLoweringContext {
  llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
  mlir::ConvertToHloModule* converter;
  xla::XlaBuilder* builder;
};

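// Lowers each operand in `values` to its XlaOp and appends it to `results`.
// Fails if any operand has not been lowered yet.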
mlir::LogicalResult GetTuple(mlir::Operation* op,
                             mlir::Operation::operand_range values,
                             OpLoweringContext ctx,
                             llvm::SmallVectorImpl<xla::XlaOp>& results) {
  results.reserve(values.size());
  for (mlir::Value value : values) {
    if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op)))
      return mlir::failure();
  }
  return mlir::success();
}

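// Same as GetTuple, but takes an ArrayRef of values instead of an op's
// operand range.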
mlir::LogicalResult GetXlaOps(mlir::Operation* op,
                              llvm::ArrayRef<mlir::Value> values,
                              OpLoweringContext ctx,
                              llvm::SmallVectorImpl<xla::XlaOp>& results) {
  results.reserve(values.size());
  for (mlir::Value value : values) {
    if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op)))
      return mlir::failure();
  }
  return mlir::success();
}

}  // namespace

namespace mlir {
namespace mhlo {
namespace {

LogicalResult ExportXlaOp(ComputeReshapeShapeOp, OpLoweringContext) {
  // This op has no expression in the legacy export format. It can be expanded
  // to a sequence of operations if needed in the future, but would feed into
  // ops creating unsupported dynamic shapes.
  return failure();
}

LogicalResult ExportXlaOp(CstrReshapableOp, OpLoweringContext) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp token;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  auto operand_shape = ctx.builder->GetShape(operand).value();
  value_map[op] = xla::internal::XlaBuilderFriend::BuildAddDependency(
      ctx.builder, operand, token, operand_shape);
  return success();
}

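// Exports mhlo.all_gather. The shard count is inferred as the ratio of the
// result and operand sizes along all_gather_dim, so both shapes must be
// static.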
LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  TensorType operand_type = op.operand().getType().cast<TensorType>();
  TensorType result_type = op.getType();
  if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
    return failure();
  auto all_gather_dim = op.all_gather_dim();
  int64_t shard_count = result_type.getDimSize(all_gather_dim) /
                        operand_type.getDimSize(all_gather_dim);
  value_map[op] = xla::AllGather(operand, all_gather_dim, shard_count,
                                 Convert_replica_groups(op.replica_groups()),
                                 Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }

  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::AllReduce(
      operand, computation, Convert_replica_groups(op.replica_groups()),
      Convert_channel_handle(op.channel_handle()), std::nullopt,
      Convert_use_global_device_ids(op.use_global_device_ids()));
  return success();
}

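// Exports mhlo.reduce_scatter. As with all_gather, the shard count is derived
// from the static operand and result sizes along scatter_dimension.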
LogicalResult ExportXlaOp(ReduceScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  TensorType operand_type = op.operand().getType().cast<TensorType>();
  TensorType result_type = op.getType();
  if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
    return failure();
  auto scatter_dim = op.scatter_dimension();
  int64_t shard_count = operand_type.getDimSize(scatter_dim) /
                        result_type.getDimSize(scatter_dim);

  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }

  value_map[op] =
      xla::ReduceScatter(operand, computation, scatter_dim, shard_count,
                         Convert_replica_groups(op.replica_groups()),
                         Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::BitcastConvertType(
      operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
  return success();
}

LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
  auto type = op.getType().dyn_cast<RankedTensorType>();
  if (!type) return failure();
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] =
      BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
                     Convert_broadcast_dimensions(op.broadcast_dimensions()));
  return success();
}

LogicalResult ExportXlaOp(CosineOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto result = op.getResult();
  xla::XlaOp arg;
  if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op)))
    return mlir::failure();
  auto xla_result = xla::Cos(Unwrap(arg));
  value_map[result] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(DotOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  value_map[op] = xla::Dot(
      lhs, rhs, Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type);
  return mlir::success();
}

LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  value_map[op] = xla::DotGeneral(
      lhs, rhs, Convert_dot_dimension_numbers(op.dot_dimension_numbers()),
      Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type);
  return mlir::success();
}

LogicalResult ExportXlaOp(DomainOp op, OpLoweringContext ctx) {
  auto& valueMap = *ctx.values;

  xla::Shape shape = xla::TypeToShape(op.getResult().getType());
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), valueMap, &operand, op))) return failure();

  auto entry = CreateOpShardingFromStringRef(op.entry_metadata());
  if (!entry) return failure();
  auto exit = CreateOpShardingFromStringRef(op.exit_metadata());
  if (!exit) return failure();

  valueMap[op] = xla::internal::XlaBuilderFriend::BuildDomain(
      ctx.builder, operand, *exit, *entry, shape);
  return success();
}

LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
  xla::XlaComputation true_branch;
  xla::XlaComputation false_branch;
  auto& value_map = *ctx.values;

  // mhlo.IfOp does not have any operands or block arguments. The computations
  // inside the region blocks use implicit captures of values defined above.
  // In order to create the xla parameters for the functions corresponding to
  // the IfOp regions, we need to infer each region block's arguments, using
  // all the values used in the region but defined above. Note that in case
  // there are zero implicit captures for a region, we use an empty tuple as
  // the xla parameter.
  //
  // Note that the implicit values used in the true and false branch regions
  // might be different and, as a result, the xla parameters for the
  // corresponding regions could have different shapes.
  llvm::SetVector<mlir::Value> implicit_true_operand_set,
      implicit_false_operand_set;
  getUsedValuesDefinedAbove(op.true_branch(), op.true_branch(),
                            implicit_true_operand_set);
  getUsedValuesDefinedAbove(op.false_branch(), op.false_branch(),
                            implicit_false_operand_set);

  llvm::SmallVector<mlir::Value> implicit_true_operands(
      implicit_true_operand_set.begin(), implicit_true_operand_set.end());
  llvm::SmallVector<mlir::Value> implicit_false_operands(
      implicit_false_operand_set.begin(), implicit_false_operand_set.end());

  // Create xla parameters for the functions corresponding to the IfOp regions
  // using the implicit capture operands. Also export the instructions within
  // those regions.
  if (failed(ctx.converter->LowerRegionAsComputation(
          &op.true_branch(), &true_branch,
          llvm::makeArrayRef(implicit_true_operands),
          /*ensure_single_arg*/ true)) ||
      failed(ctx.converter->LowerRegionAsComputation(
          &op.false_branch(), &false_branch,
          llvm::makeArrayRef(implicit_false_operands),
          /*ensure_single_arg*/ true))) {
    return failure();
  }

  // Create the Xla pred argument.
  xla::XlaOp pred;
  if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure();

  // Create the true branch Xla argument.
  llvm::SmallVector<xla::XlaOp> true_args;
  if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args)))
    return failure();
  xla::XlaOp true_arg =
      true_args.size() == 1 ? true_args[0] : Tuple(ctx.builder, true_args);

  // Create the false branch Xla argument.
  llvm::SmallVector<xla::XlaOp> false_args;
  if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args)))
    return failure();
  xla::XlaOp false_arg =
      false_args.size() == 1 ? false_args[0] : Tuple(ctx.builder, false_args);

  // Create XLA Conditional op.
  auto ifop =
      xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch);

  // mhlo.IfOp may have multiple results; if so, untuple the result of the XLA
  // conditional.
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = ifop;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(ifop, item.index());
    }
  }

  return success();
}

LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
  llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
  // OperandRange operands = op.branch_operands();
  MutableArrayRef<Region> branches = op.branches();
  llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
  std::vector<xla::XlaComputation> computations(branches.size());
  std::vector<xla::XlaComputation*> computations_p(branches.size());

  // mhlo.CaseOp does not have any operands or block arguments. The computations
  // inside the region blocks use implicit captures of values defined above.
  // In order to create the xla parameters for the functions corresponding to
  // the CaseOp regions, we need to infer each region block's arguments, using
  // all the values used in the region but defined above. Note that in case
  // there are zero implicit captures for a region, we use an empty tuple as
  // the xla parameter.
  //
  // Note that the implicit values used in the regions might
  // be different and, as a result, the xla parameters for the corresponding
  // regions could have different shapes.
  for (unsigned i = 0; i < branches.size(); ++i) {
    llvm::SetVector<mlir::Value> implicit_operand_set;
    getUsedValuesDefinedAbove(branches[i], branches[i], implicit_operand_set);
    llvm::SmallVector<mlir::Value> implicit_operands(
        implicit_operand_set.begin(), implicit_operand_set.end());

    // Create the branches[i]'s Xla argument.
    llvm::SmallVector<xla::XlaOp> args;
    if (failed(GetXlaOps(op, implicit_operands, ctx, args))) return failure();
    branch_operands[i] = args.size() == 1 ? args[0] : Tuple(ctx.builder, args);

    // Create xla parameters for the function corresponding to region
    // branches[i] using the implicit capture operands. Also export the
    // instructions within that region.
    computations_p[i] = &computations[i];
    if (failed(ctx.converter->LowerRegionAsComputation(
            &branches[i], computations_p[i],
            llvm::makeArrayRef(implicit_operands),
            /*ensure_single_arg*/ true)))
      return failure();
  }

  xla::XlaOp index;
  if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure();

  xla::XlaOp caseop = xla::Conditional(index, computations_p, branch_operands);

  // mhlo.CaseOp may have multiple results; if so, untuple the result of the
  // XLA conditional.
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = caseop;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(caseop, item.index());
    }
  }
  return success();
}

// Specialize CompareOp export to set broadcast_dimensions argument.
mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
                                OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  auto dir = Convert_comparison_direction(
      mlir::mhlo::stringifyComparisonDirection(op.comparison_direction()));
  auto type_attr = op.compare_typeAttr();

  xla::XlaOp xla_result;
  if (type_attr && type_attr.getValue() != mlir::mhlo::ComparisonType::NOTYPE) {
    auto type = xla::StringToComparisonType(
                    stringifyComparisonType(type_attr.getValue()).str())
                    .ValueOrDie();
    xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
  } else {
    xla_result = xla::Compare(lhs, rhs, dir);
  }
  value_map[op] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(ConstantOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(mlir::mhlo::ConvolutionOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  xla::XlaOp xla_result = xla::ConvGeneralDilated(
      lhs, rhs, Convert_window_strides(op.window_strides()),
      Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
      Convert_rhs_dilation(op.rhs_dilation()),
      xla::ConvertConvDimensionNumbers(op.dimension_numbers()),
      Convertuint64_t(op.feature_group_count()),
      Convertuint64_t(op.batch_group_count()),
      Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type, Convert_window_reversal(op.window_reversal()));
  value_map[op] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::ConvertElementType(
      operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
  return success();
}

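// Exports mhlo.custom_call. Only single-result custom calls are supported, and
// a called computation cannot be combined with explicit operand/result
// layouts.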
LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
  if (op.getNumResults() != 1)
    return op.emitOpError() << "with multiple results cannot be exported";

  if (op.called_computations().size() > 1)
    return op.emitOpError()
           << "cannot export with more than one called computations";

  // A custom call can be exported either with a called computation or with
  // layout attributes, but the XlaBuilder API does not allow both.
1101 if (!op.called_computations().empty() && op.operand_layouts() &&
1102 op.result_layouts()) {
1103 return op.emitOpError() << "cannot export if both called computation and "
1104 "layouts are specified";
1105 }
1106
1107 Value result = op.getResult(0);
1108 llvm::SmallVector<xla::XlaOp> args;
1109 if (failed(GetTuple(op, op.operands(), ctx, args))) return failure();
1110 auto xla_api_version = xla::ConvertCustomCallApiVersion(op.api_version());
1111 if (!xla_api_version.ok()) return failure();
1112 auto& value_map = *ctx.values;
1113
1114 if (op.called_computations().size() == 1) {
1115 mlir::func::FuncOp callee = ctx.converter->LookUpSymbol(
1116 op.called_computations()[0].cast<FlatSymbolRefAttr>());
1117 if (failed(ctx.converter->RunOnFunction(callee))) return failure();
1118 xla::XlaComputation& computation =
1119 ctx.converter->GetLoweredComputation(callee);
1120 value_map[result] = xla::CustomCallWithComputation(
1121 ctx.builder, std::string(op.call_target_name()), args, computation,
1122 xla::TypeToShape(result.getType()), std::string(op.backend_config()),
1123 op.has_side_effect(),
1124 /*output_operand_aliasing=*/{},
1125 /*literal=*/nullptr,
1126 /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
1127 /*api_version=*/*xla_api_version);
1128 return success();
1129 }
1130
1131 if (op.operand_layouts() && op.result_layouts()) {
1132 auto operand_shapes_with_layout = ConvertTypesToShapesWithLayout(
1133 op.getOperandTypes(), op.operand_layouts().getValue());
1134 xla::Shape result_shape_with_layout = GetCustomCallResultShapeWithLayout(
1135 result.getType(), op.result_layouts().getValue());
1136 value_map[result] = xla::CustomCallWithLayout(
1137 ctx.builder, std::string(op.call_target_name()), args,
1138 result_shape_with_layout, operand_shapes_with_layout,
1139 std::string(op.backend_config()), op.has_side_effect(),
1140 /*output_operand_aliasing=*/{},
1141 /*literal=*/nullptr,
1142 /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
1143 /*api_version=*/*xla_api_version);
1144 return success();
1145 }
1146
1147 value_map[result] = xla::CustomCall(
1148 ctx.builder, std::string(op.call_target_name()), args,
1149 xla::TypeToShape(result.getType()), std::string(op.backend_config()),
1150 op.has_side_effect(), /*output_operand_aliasing=*/{},
1151 /*literal=*/nullptr,
1152 /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
1153 /*api_version=*/*xla_api_version);
1154 return success();
1155 }
1156
ExportXlaOp(InfeedOp op,OpLoweringContext ctx)1157 LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
1158 auto& value_map = *ctx.values;
1159 xla::XlaOp token;
1160 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
1161
1162 // mhlo.infeed produces multiple results. The shape argument expected by the
1163 // xla client API is a tuple type with two element-types:
1164 // data_type : A tuple containing all the mhlo.infeedOp result types except
1165 // the token type.
1166 // token_type : The last result type of mhlo.infeedOp.
1167 auto result_types = op.getResultTypes();
1168 auto num_results = op.getNumResults();
1169
1170 xla::Shape token_shape = xla::TypeToShape(result_types[num_results - 1]);
1171 std::vector<xla::Shape> subshapes;
1172 for (const auto& item : llvm::enumerate(result_types)) {
1173 if (item.index() == num_results - 1) break;
1174 subshapes.push_back(xla::TypeToShape(item.value()));
1175 }
1176
1177 xla::Shape data_shape = xla::ShapeUtil::MakeTupleShape(subshapes);
1178 auto xla_result =
1179 xla::InfeedWithToken(token, data_shape, std::string(op.infeed_config()));
1180 ctx.builder->ClearSharding();
1181
1182 if (!subshapes.empty()) {
1183 auto data_tuple_element = xla::GetTupleElement(xla_result, 0);
1184 for (const auto& item : llvm::enumerate(op.getResults())) {
1185 if (item.index() == num_results - 1) break;
1186 value_map[item.value()] =
1187 xla::GetTupleElement(data_tuple_element, item.index());
1188 }
1189 }
1190
1191 value_map[op.getResult(num_results - 1)] =
1192 xla::GetTupleElement(xla_result, 1);
1193
1194 return success();
1195 }
1196
ExportXlaOp(IotaOp op,OpLoweringContext ctx)1197 LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
1198 auto& value_map = *ctx.values;
1199 value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
1200 op.iota_dimension());
1201 return success();
1202 }
1203
ExportXlaOp(MapOp op,OpLoweringContext ctx)1204 LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) {
1205 auto& value_map = *ctx.values;
1206 xla::XlaComputation computation;
1207 if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
1208 &computation))) {
1209 return failure();
1210 }
1211 llvm::SmallVector<xla::XlaOp> operands;
1212 if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
1213 value_map[op] = xla::Map(ctx.builder, operands, computation,
1214 Convert_dimensions(op.dimensions()));
1215 return success();
1216 }
1217
ExportXlaOp(OutfeedOp op,OpLoweringContext ctx)1218 LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
1219 auto& value_map = *ctx.values;
1220
1221 llvm::SmallVector<xla::XlaOp> operands;
1222 if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
1223
1224 xla::XlaOp operand = Tuple(ctx.builder, operands);
1225
1226 std::vector<xla::Shape> subshapes;
1227 for (auto operand : op.operands())
1228 subshapes.push_back(xla::TypeToShape(operand.getType()));
1229
1230 xla::Shape shape_with_layout = xla::ShapeUtil::MakeTupleShape(subshapes);
1231
1232 xla::XlaOp token;
1233 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
1234
1235 value_map[op] = xla::OutfeedWithToken(operand, token, shape_with_layout,
1236 std::string(op.outfeed_config()));
1237 return success();
1238 }
1239
ExportXlaOp(PartitionIdOp op,OpLoweringContext ctx)1240 LogicalResult ExportXlaOp(PartitionIdOp op, OpLoweringContext ctx) {
1241 auto& value_map = *ctx.values;
1242 xla::Shape shape = xla::TypeToShape(op.getResult().getType());
1243 value_map[op] =
1244 xla::internal::XlaBuilderFriend::BuildPartitionId(ctx.builder, shape);
1245 return success();
1246 }
1247
ExportXlaOp(PadOp op,OpLoweringContext ctx)1248 LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
1249 auto& value_map = *ctx.values;
1250 xla::PaddingConfig padding_config;
1251 auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
1252 auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
1253 auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
1254 for (int64_t i = 0, end = edge_padding_low.size(); i < end; ++i) {
1255 auto* dims = padding_config.add_dimensions();
1256 dims->set_edge_padding_low(edge_padding_low[i]);
1257 dims->set_edge_padding_high(edge_padding_high[i]);
1258 dims->set_interior_padding(interior_padding[i]);
1259 }
1260 xla::XlaOp operand, padding_value;
1261 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1262 if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op)))
1263 return failure();
1264
1265 value_map[op] = xla::Pad(operand, padding_value, padding_config);
1266 return success();
1267 }
1268
ExportXlaOp(RecvOp op,OpLoweringContext ctx)1269 LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) {
1270 auto& value_map = *ctx.values;
1271
1272 xla::XlaOp token;
1273 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
1274
1275 // mhlo.recvOp produces multiple results. The shape argument expected by the
1276 // xla client API is a tuple type with two element-types:
1277 // data_type : A tuple containing all the mhlo.RecvOp result types except
1278 // the token type.
1279 // token_type : The last result type of mhlo.recvOp.
1280 auto result_types = op.getResultTypes();
1281 auto num_results = op.getNumResults();
1282
1283 xla::Shape token_shape = xla::TypeToShape(result_types[num_results - 1]);
1284 std::vector<xla::Shape> subshapes;
1285 for (const auto& item : llvm::enumerate(result_types)) {
1286 if (item.index() == num_results - 1) break;
1287 subshapes.push_back(xla::TypeToShape(item.value()));
1288 }
1289
1290 xla::Shape data_shape;
1291 if (subshapes.size() == 1)
1292 data_shape = subshapes[0];
1293 else
1294 data_shape = xla::ShapeUtil::MakeTupleShape(subshapes);
1295
1296 xla::XlaOp xla_result;
1297 if (op.is_host_transfer()) {
1298 xla_result = xla::RecvFromHost(token, data_shape,
1299 Convert_channel_handle(op.channel_handle()));
1300 } else {
1301 xla_result = xla::RecvWithToken(
1302 token, data_shape, Convert_channel_handle(op.channel_handle()));
1303 }
1304
1305 auto data_tuple_element = xla::GetTupleElement(xla_result, 0);
1306 if (subshapes.size() == 1) {
1307 value_map[op.getResult(0)] = data_tuple_element;
1308 } else {
1309 for (const auto& item : llvm::enumerate(op.getResults())) {
1310 if (item.index() == num_results - 1) break;
1311 value_map[item.value()] =
1312 xla::GetTupleElement(data_tuple_element, item.index());
1313 }
1314 }
1315
1316 value_map[op.getResult(num_results - 1)] =
1317 xla::GetTupleElement(xla_result, 1);
1318
1319 return success();
1320 }
1321
1322 LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
1323 auto& value_map = *ctx.values;
1324 xla::XlaComputation body;
1325 if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
1326 return failure();
1327 }
1328 llvm::SmallVector<xla::XlaOp> operands, init_values;
1329 if (failed(GetTuple(op, op.operands(), ctx, operands)) ||
1330 failed(GetTuple(op, op.init_values(), ctx, init_values))) {
1331 return failure();
1332 }
1333 xla::XlaOp result =
1334 xla::Reduce(ctx.builder, operands, init_values, body,
1335 Convert_broadcast_dimensions(op.dimensions()));
1336 if (op.getNumResults() == 1) {
1337 value_map[op.getResult(0)] = result;
1338 } else {
1339 for (const auto& item : llvm::enumerate(op.getResults())) {
1340 value_map[item.value()] = xla::GetTupleElement(result, item.index());
1341 }
1342 }
1343 return success();
1344 }
1345
1346 LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
1347 auto& value_map = *ctx.values;
1348 xla::XlaComputation body;
1349 if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
1350 return failure();
1351 }
1352 llvm::SmallVector<xla::XlaOp> operands, init_values;
1353 if (failed(GetTuple(op, op.operands(), ctx, operands)) ||
1354 failed(GetTuple(op, op.init_values(), ctx, init_values))) {
1355 return failure();
1356 }
1357
1358 xla::XlaOp result = xla::ReduceWindowWithGeneralPadding(
1359 operands, init_values, body, ConvertDenseIntAttr(op.window_dimensions()),
1360 ConvertDenseIntAttr(op.window_strides()),
1361 ConvertDenseIntAttr(op.base_dilations()),
1362 ConvertDenseIntAttr(op.window_dilations()),
1363 Convert_padding(op.padding()));
1364
1365 if (op.getNumResults() == 1) {
1366 value_map[op.getResult(0)] = result;
1367 } else {
1368 for (const auto& item : llvm::enumerate(op.getResults())) {
1369 value_map[item.value()] = xla::GetTupleElement(result, item.index());
1370 }
1371 }
1372 return success();
1373 }
1374
1375 LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
1376 auto& value_map = *ctx.values;
1377 xla::XlaOp operand;
1378 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1379
1380 value_map[op] =
1381 xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions());
1382 return success();
1383 }
1384
1385 LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
1386 // Failure on purpose because `mhlo::ReturnOp` will be handled by
1387 // special purpose logic in `ConvertToHloModule::Lower`.
1388 return failure();
1389 }
1390
1391 LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
1392 auto& value_map = *ctx.values;
1393 auto results = op.getResults();
1394 auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
1395 auto xla_result = xla::RngBitGenerator(
1396 static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
1397 xla::TypeToShape(results[1].getType()));
1398
1399 for (const auto& item : llvm::enumerate(results))
1400 value_map[item.value()] = xla::GetTupleElement(xla_result, item.index());
1401
1402 return mlir::success();
1403 }
1404
1405 LogicalResult ExportXlaOp(XlaRngGetAndUpdateStateOp op, OpLoweringContext ctx) {
1406 // This op does not exist in the XLA builder interface.
1407 (*ctx.values)[op.getResult()] =
1408 xla::internal::XlaBuilderFriend::BuildRngGetAndUpdateState(
1409 ctx.builder, static_cast<int64_t>(op.delta()),
1410 xla::TypeToShape(op.getType()));
1411 return mlir::success();
1412 }
1413
1414 LogicalResult ExportXlaOp(BatchNormGradOp op, OpLoweringContext ctx) {
1415 auto& value_map = *ctx.values;
1416 auto results = op.getResults();
1417
1418 xla::XlaOp operand, scale, mean, variance, grad_output;
1419 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1420 if (failed(GetXlaOp(op.scale(), value_map, &scale, op))) return failure();
1421 if (failed(GetXlaOp(op.mean(), value_map, &mean, op))) return failure();
1422 if (failed(GetXlaOp(op.variance(), value_map, &variance, op)))
1423 return failure();
1424 if (failed(GetXlaOp(op.grad_output(), value_map, &grad_output, op)))
1425 return failure();
1426
1427 auto xla_result =
1428 xla::BatchNormGrad(operand, scale, mean, variance, grad_output,
1429 ConvertAPFloat(op.epsilon()), op.feature_index());
1430
1431 for (const auto& item : llvm::enumerate(results))
1432 value_map[item.value()] = xla::GetTupleElement(xla_result, item.index());
1433
1434 return mlir::success();
1435 }
1436
1437 LogicalResult ExportXlaOp(BatchNormTrainingOp op, OpLoweringContext ctx) {
1438 auto& value_map = *ctx.values;
1439 auto results = op.getResults();
1440
1441 xla::XlaOp operand, scale, offset;
1442 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1443 if (failed(GetXlaOp(op.scale(), value_map, &scale, op))) return failure();
1444 if (failed(GetXlaOp(op.offset(), value_map, &offset, op))) return failure();
1445
1446 auto xla_result = xla::BatchNormTraining(
1447 operand, scale, offset, ConvertAPFloat(op.epsilon()), op.feature_index());
1448
1449 for (const auto& item : llvm::enumerate(results))
1450 value_map[item.value()] = xla::GetTupleElement(xla_result, item.index());
1451
1452 return mlir::success();
1453 }
1454
1455 LogicalResult ExportXlaOp(RngOp op, OpLoweringContext ctx) {
1456 auto& value_map = *ctx.values;
1457 xla::XlaOp a, b;
1458 if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure();
1459 if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure();
1460
1461 if (op.rng_distribution() == RngDistribution::UNIFORM) {
1462 value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType()));
1463 return success();
1464 } else if (op.rng_distribution() == RngDistribution::NORMAL) {
1465 value_map[op] = xla::RngNormal(a, b, xla::TypeToShape(op.getType()));
1466 return success();
1467 }
1468 return failure();
1469 }
1470
1471 LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
1472 auto& value_map = *ctx.values;
1473 xla::XlaComputation update_computation;
1474 if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
1475 &update_computation))) {
1476 return failure();
1477 }
1478 xla::ScatterDimensionNumbers dimension_numbers =
1479 Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
1480
1481 llvm::SmallVector<xla::XlaOp> operands;
1482 llvm::SmallVector<xla::XlaOp> updates;
1483 if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
1484 if (failed(GetTuple(op, op.updates(), ctx, updates))) return failure();
1485
1486 xla::XlaOp scatter_indices;
1487 if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op)))
1488 return failure();
1489
1490 auto scatter_op = xla::Scatter(operands, scatter_indices, updates,
1491 update_computation, dimension_numbers,
1492 op.indices_are_sorted(), op.unique_indices());
1493 if (op->getNumResults() == 1) {
1494 value_map[op.getResult(0)] = scatter_op;
1495 return success();
1496 }
1497
1498 // mhlo.ScatterOp supports multiple returns, untuple all the results of XLA's.
1499 for (const auto& it : llvm::enumerate(op.getResults())) {
1500 value_map[it.value()] = xla::GetTupleElement(scatter_op, it.index());
1501 }
1502
1503 return success();
1504 }
1505
1506 LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
1507 auto& value_map = *ctx.values;
1508 xla::XlaComputation select;
1509 xla::XlaComputation scatter;
1510 if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) ||
1511 failed(
1512 ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) {
1513 return failure();
1514 }
1515 xla::XlaOp operand, source, init_value;
1516 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1517 if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure();
1518 if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
1519 return failure();
1520
1521 value_map[op] = xla::SelectAndScatterWithGeneralPadding(
1522 operand, select, ConvertDenseIntAttr(op.window_dimensions()),
1523 ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()),
1524 source, init_value, scatter);
1525 return success();
1526 }
1527
1528 LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
1529 auto& value_map = *ctx.values;
1530
1531 llvm::SmallVector<xla::XlaOp> operands;
1532 if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
1533
1534 xla::XlaOp operand;
1535 if (operands.size() == 1)
1536 operand = operands[0];
1537 else
1538 operand = Tuple(ctx.builder, operands);
1539
1540 xla::XlaOp token;
1541 if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
1542
1543 if (op.is_host_transfer()) {
1544 value_map[op] = xla::SendToHost(
1545 operand, token, operand.builder()->GetShape(operand).value(),
1546 Convert_channel_handle(op.channel_handle()));
1547 return success();
1548 }
1549 value_map[op] = xla::SendWithToken(
1550 operand, token, Convert_channel_handle(op.channel_handle()));
1551 return success();
1552 }
1553
1554 mlir::LogicalResult ExportXlaOp(mlir::mhlo::SineOp op, OpLoweringContext ctx) {
1555 auto& value_map = *ctx.values;
1556 auto result = op.getResult();
1557 xla::XlaOp arg;
1558 if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op)))
1559 return mlir::failure();
1560 auto xla_result = xla::Sin(Unwrap(arg));
1561 value_map[result] = xla_result;
1562 return mlir::success();
1563 }
1564
1565 LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
1566 return failure();
1567 }
1568
1569 LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
1570 xla::XlaComputation comparator;
1571 if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
1572 &comparator)))
1573 return failure();
1574
1575 llvm::SmallVector<xla::XlaOp> operands;
1576 if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
1577 auto sorted = xla::Sort(operands, comparator, op.dimension(), op.is_stable());
1578
1579 auto& value_map = *ctx.values;
1580 auto shape_or = sorted.builder()->GetShape(sorted);
1581 if (!shape_or.ok()) {
1582 return op.emitError(shape_or.status().ToString());
1583 }
1584
1585 xla::Shape& shape = shape_or.ValueOrDie();
1586 if (!shape.IsTuple()) {
1587 value_map[op.getResult(0)] = sorted;
1588 return success();
1589 }
1590
1591 // MLIR's sort supports multiple returns, untuple all the results of XLA's.
1592 for (const auto& it : llvm::enumerate(op.getResults())) {
1593 value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
1594 }
1595 return success();
1596 }
1597
1598 LogicalResult ExportXlaOp(SubtractOp op, OpLoweringContext ctx) {
1599 auto& value_map = *ctx.values;
1600 auto result = op.getResult();
1601 xla::XlaOp lhs;
1602 if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &lhs, op)))
1603 return mlir::failure();
1604 xla::XlaOp rhs;
1605 if (failed(GetXlaOp(*op.getODSOperands(1).begin(), value_map, &rhs, op)))
1606 return mlir::failure();
1607 auto xla_result = xla::Sub(Unwrap(lhs), Unwrap(rhs));
1608 value_map[result] = xla_result;
1609 return mlir::success();
1610 }
1611
1612 LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) {
1613 // TODO(atondwal): remove mhlo.trace
1614 return success();
1615 }
1616
1617 LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
1618 // Intentional as UnaryEinsumOp is always lowered to the EinsumOp with two
1619 // operands.
1620 return failure();
1621 }
1622
1623 LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
1624 xla::XlaComputation condition;
1625 xla::XlaComputation body;
1626 if (failed(ctx.converter->LowerRegionAsComputation(
1627 &op.body(), &body, llvm::None, /*ensure_single_arg*/ true)) ||
1628 failed(ctx.converter->LowerRegionAsComputation(
1629 &op.cond(), &condition, llvm::None, /*ensure_single_arg*/ true))) {
1630 return failure();
1631 }
1632
1633 // If MHLO's WhileOp has multiple operands, create an xla::Tuple from those
1634 // operands to be used as the sole operand of xla::While.
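// For example, a while op carrying (tensor<f32>, tensor<i32>) loop state is
// lowered with a single tuple-shaped operand (f32[], s32[]); its results are
// untupled again further below.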
1635 llvm::SmallVector<xla::XlaOp> operands;
1636 if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure();
1637
1638 xla::XlaOp operand = operands[0];
1639 if (operands.size() > 1) operand = Tuple(ctx.builder, operands);
1640
1641 auto whileop = xla::While(condition, body, operand);
1642
1643 auto& value_map = *ctx.values;
1644 auto shape_or = whileop.builder()->GetShape(whileop);
1645 if (!shape_or.ok()) {
1646 return op.emitError(shape_or.status().ToString());
1647 }
1648
1649 xla::Shape& shape = shape_or.ValueOrDie();
1650 if (!shape.IsTuple()) {
1651 value_map[op.getResult(0)] = whileop;
1652 return success();
1653 }
1654
1655 // mhlo.WhileOp supports multiple returns, untuple all the results of XLA's.
1656 for (const auto& it : llvm::enumerate(op.getResults())) {
1657 value_map[it.value()] = xla::GetTupleElement(whileop, it.index());
1658 }
1659
1660 return success();
1661 }
1662
1663 LogicalResult ExportXlaOp(OptimizationBarrierOp op, OpLoweringContext ctx) {
1664 // If MHLO's OptimizationBarrierOp has multiple operands, create an
1665 // xla::Tuple from those operands to be used as the sole operand of
1666 // xla::OptimizationBarrier.
1667 llvm::SmallVector<xla::XlaOp> operands;
1668 if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure();
1669 if (operands.empty()) return success();
1670
1671 auto& value_map = *ctx.values;
1672 if (operands.size() == 1) {
1673 value_map[op.getResult(0)] = xla::OptimizationBarrier(operands[0]);
1674 } else {
1675 auto result = xla::OptimizationBarrier(Tuple(ctx.builder, operands));
1676
1677 for (const auto& it : llvm::enumerate(op.getResults())) {
1678 value_map[it.value()] = xla::GetTupleElement(result, it.index());
1679 }
1680 }
1681
1682 return success();
1683 }
1684
1685 LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
1686 if (!op.fusion_kind()) {
1687 op.emitOpError() << "requires fusion kind for HLO translation";
1688 return failure();
1689 }
1690
1691 xla::XlaComputation fused_computation;
1692 if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(),
1693 &fused_computation)))
1694 return failure();
1695
1696 auto& values = *ctx.values;
1697 llvm::SmallVector<xla::XlaOp, 4> operands;
1698 for (auto operand : op.operands()) operands.push_back(values[operand]);
1699
1700 auto fusion_kind_string =
1701 mlir::mhlo::stringifyFusionKind(op.fusion_kind().getValue());
1702 xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
1703 ctx.builder, operands,
1704 absl::string_view(fusion_kind_string.data(), fusion_kind_string.size()),
1705 fused_computation);
1706 if (op.getNumResults() == 1) {
1707 values[op.getResult(0)] = fusion;
1708 } else {
1709 for (const auto& item : llvm::enumerate(op.getResults())) {
1710 values[item.value()] = xla::GetTupleElement(fusion, item.index());
1711 }
1712 }
1713 return success();
1714 }
1715
1716 LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
1717 auto& value_map = *ctx.values;
1718 xla::XlaOp operand;
1719 if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
1720 xla::XlaOp bitcast = xla::internal::XlaBuilderFriend::BuildBitcast(
1721 ctx.builder, operand, xla::TypeToShape(op.getType()));
1722 value_map[op] = bitcast;
1723 if (ctx.converter->GetOptions().propagate_bitcast_layouts_to_backend_config) {
1724 // Encode the source and result layout of the bitcast into the XLA HLO
1725 // backend config as a protobuf. Note that this is a temporary solution
1726 // which will go away once XLA:GPU stops falling back to XLA HLO Elemental
1727 // IR emitters.
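// For example (hypothetical attribute values), with a "result_layout" of
// [1, 0] and a "source_layout" of [0, 1] on the op, the BitcastBackendConfig
// serialized below records both layouts so the backend can recover them from
// the exported HLO instruction.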
1728 xla::HloInstructionProto* bitcast_proto =
1729 xla::internal::XlaBuilderFriend::GetInstruction(bitcast);
1730 xla::HloInstructionProto* operand_proto =
1731 xla::internal::XlaBuilderFriend::GetInstruction(operand);
1732 xla::LayoutProto result_layout =
1733 ExtractLayout(op, bitcast_proto->shape().dimensions_size(),
1734 "result_layout")
1735 .ToProto();
1736 xla::LayoutProto source_layout =
1737 ExtractLayout(op, operand_proto->shape().dimensions_size(),
1738 "source_layout")
1739 .ToProto();
1740 xla::gpu::BitcastBackendConfig bitcast_config;
1741 *bitcast_config.mutable_source_layout() = source_layout;
1742 *bitcast_config.mutable_result_layout() = result_layout;
1743 *bitcast_proto->mutable_backend_config() =
1744 bitcast_config.SerializeAsString();
1745 }
1746 return success();
1747 }
1748
1749 LogicalResult ExportXlaOp(RealDynamicSliceOp op, OpLoweringContext ctx) {
1750 return failure();
1751 }
1752
1753 LogicalResult ExportXlaOp(DynamicPadOp op, OpLoweringContext ctx) {
1754 return failure();
1755 }
1756
1757 LogicalResult ExportXlaOp(DynamicGatherOp op, OpLoweringContext ctx) {
1758 return failure();
1759 }
1760
1761 LogicalResult ExportXlaOp(DynamicConvOp op, OpLoweringContext ctx) {
1762 return failure();
1763 }
1764
1765 LogicalResult ExportXlaOp(UniformQuantizeOp op, OpLoweringContext ctx) {
1766 // Currently, it doesn't have an XLA builder equivalent.
1767 // TODO(b/230671877): Implement XLA import/export for quantized MHLO ops.
1768 return failure();
1769 }
1770
1771 LogicalResult ExportXlaOp(UniformDequantizeOp op, OpLoweringContext ctx) {
1772 // Currently, it doesn't have an XLA builder equivalent.
1773 // TODO(b/230671877): Implement XLA import/export for quantized MHLO ops.
1774 return failure();
1775 }
1776
1777 } // namespace
1778 } // namespace mhlo
1779 } // namespace mlir
1780
1781 #include "tensorflow/compiler/mlir/xla/operator_writers.inc"
1782
1783 namespace mlir {
1784 namespace {
1785
1786 StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
1787 xla::Layout layout) {
1788 auto dense_attr = attr.dyn_cast<DenseElementsAttr>();
1789 if (!dense_attr)
1790 return tensorflow::errors::Unimplemented(
1791 "Only dense elements attr are supported");
1792
1793 xla::Shape shape = xla::TypeToShape(dense_attr.getType());
1794
1795 #define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type) \
1796 case xla_type: { \
1797 xla::Array<cpp_type> source_data(shape.dimensions()); \
1798 source_data.SetValues(dense_attr.getValues<cpp_type>()); \
1799 return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
1800 }
1801
1802 switch (shape.element_type()) {
1803 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool)
1804 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float)
1805 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double)
1806 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8)
1807 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
1808 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
1809 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64_t)
1810 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
1811 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
1812 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
1813 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
1814 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
1815 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
1816 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half)
1817 ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16)
1818 default:
1819 return tensorflow::errors::Internal(absl::StrCat(
1820 "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type())));
1821 }
1822 #undef ELEMENTS_ATTR_TO_LITERAL
1823 }
1824
1825 LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
1826 xla::ShapeProto* shape) {
1827 // In the case of tuples, ShapeProtos can be nested, and so can the mlir
1828 // attribute describing the layout. So recurse into the subshapes in both data
1829 // structures in parallel.
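// For example, a tuple shape (f32[2,3], s32[4]) paired with a layout attribute
// [[1, 0], [0]] assigns minor-to-major {1, 0} to the first element and {0} to
// the second; entries that are unit attributes (used for tokens) are skipped.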
1830 if (shape->element_type() == xla::TUPLE) {
1831 auto subshapes = shape->mutable_tuple_shapes();
1832
1833 // 'layout' does not take the token attribute into account, so skip the
1834 // corresponding entry from xla shape proto.
1835 size_t subshapes_data_size = subshapes->size();
1836 if (!subshapes->empty() &&
1837 subshapes->Mutable(subshapes->size() - 1)->element_type() == xla::TOKEN)
1838 subshapes_data_size = subshapes->size() - 1;
1839
1840 if (layout.size() != subshapes_data_size) {
1841 op->emitOpError() << "Expected layout of size " << subshapes_data_size
1842 << ", but found " << layout.size();
1843 return failure();
1844 }
1845 for (int i = 0; i < subshapes_data_size; i++) {
1846 mlir::Attribute child = layout[i];
1847 if (child.isa<mlir::UnitAttr>()) {
1848 // Ignore unit attributes; they are used only for tokens.
1849 continue;
1850 }
1851 mlir::ArrayAttr c = child.dyn_cast<mlir::ArrayAttr>();
1852 if (!c) {
1853 op->emitOpError() << "Type Error: Expected layout array attribute";
1854 return failure();
1855 }
1856 if (failed(ConvertLayout(op, c, subshapes->Mutable(i)))) {
1857 return failure();
1858 }
1859 }
1860 } else {
1861 int rank = shape->dimensions().size();
1862 if (rank) {
1863 if (layout.size() != rank) {
1864 return failure(); // pass error down
1865 }
1866 std::vector<int64_t> array(rank);
1867 for (int i = 0; i < rank; i++) {
1868 mlir::IntegerAttr attr = layout[i].dyn_cast<mlir::IntegerAttr>();
1869 if (!attr) {
1870 op->emitOpError() << "Type Error: Expected layout integer attribute";
1871 return failure();
1872 }
1873 array[i] = attr.getInt();
1874 }
1875 *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
1876 }
1877 }
1878 return success();
1879 }
1880
1881 // Assigns layouts from 'layout' to shape.
1882 // The function accepts any of the following shapes:
1883 // one or more array-shape(s) of infeed data
1884 // Tuple(Tuple(zero or more array-shape w.r.t. data), token_type)
1885 //
1886 // 'layout' of the mhlo.InfeedOp 'op' is
1887 // [zero or more layouts, one for each array-shape w.r.t. data]
1888 // 'layout_index' indexes into 'layout', accessing the layout corresponding to
1889 // a shape.
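// For example (hypothetical shapes), an infeed producing
// (tensor<2x2xf32>, tensor<3xi32>, !mhlo.token) would carry a layout attribute
// such as [[1, 0], [0]]: one entry per data result, minor-to-major, and no
// entry for the token.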
1890 LogicalResult ConvertInfeedtLayout(mlir::Operation* op,
1891 const mlir::ArrayAttr& layout,
1892 xla::ShapeProto* shape,
1893 int64_t layout_index = 0) {
1894 if (shape->element_type() != xla::TUPLE) {
1895 // Handles following shape:
1896 // single array-shape of infeed data
1897 mlir::ArrayAttr child_layout =
1898 layout[layout_index].dyn_cast<mlir::ArrayAttr>();
1899 if (!child_layout) {
1900 op->emitOpError() << "Type Error: Expected layout array attribute";
1901 return failure();
1902 }
1903
1904 int rank = shape->dimensions().size();
1905 if (rank) {
1906 if (child_layout.size() != rank) {
1907 return failure(); // pass error down
1908 }
1909 std::vector<int64_t> array(rank);
1910 for (int i = 0; i < rank; i++) {
1911 mlir::IntegerAttr attr = child_layout[i].dyn_cast<mlir::IntegerAttr>();
1912 if (!attr) {
1913 op->emitOpError() << "Type Error: Expected layout integer attribute";
1914 return failure();
1915 }
1916 array[i] = attr.getInt();
1917 }
1918 *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
1919 }
1920
1921 return success();
1922 }
1923
1924 auto subshapes = shape->mutable_tuple_shapes();
1925 auto datashape = subshapes->Mutable(0);
1926
1927 if (datashape->element_type() == xla::TUPLE) {
1928 // Handles following shapes:
1929 // (Tuple(zero or more array-shape w.r.t data), token_type)
1930 auto data_subshapes = datashape->mutable_tuple_shapes();
1931 if (layout.size() != data_subshapes->size()) {
1932 op->emitOpError() << "Expected " << data_subshapes->size()
1933 << " layout attribute(s) for infeed data, but found "
1934 << layout.size();
1935 return failure();
1936 }
1937
1938 for (int i = 0; i < data_subshapes->size(); i++) {
1939 if (failed(
1940 ConvertInfeedtLayout(op, layout, data_subshapes->Mutable(i), i)))
1941 return failure();
1942 }
1943 } else {
1944 // Handles following shapes:
1945 // array-shapes of two or more infeed data
1946 if (layout.size() != subshapes->size()) {
1947 op->emitOpError() << "Expected " << subshapes->size()
1948 << " layout attribute(s) for infeed data, but found "
1949 << layout.size();
1950 return failure();
1951 }
1952
1953 for (int i = 0; i < subshapes->size(); i++) {
1954 if (failed(ConvertInfeedtLayout(op, layout, subshapes->Mutable(i), i)))
1955 return failure();
1956 }
1957 }
1958
1959 return success();
1960 }
1961
1962 // MHLO and XLA HLO disagree on the meaning of addition of `pred` / `i1`, so
1963 // there has to be a special case somewhere to account for the difference. To
1964 // get the expected behavior of an `AddOp` on `i1`, we have to use `xor`. Since
1965 // the majority of the conversion is generated code, we just sidestep it here
1966 // for this single case, and inline the code to emit an `xor`.
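// Concretely, for i1 operands addition modulo 2 gives 0+0=0, 0+1=1, 1+0=1,
// and 1+1=0, which is exactly xor; hence the hand-written lowering below.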
1967 LogicalResult ExportXlaOperatorWrapped(mlir::Operation* inst,
1968 OpLoweringContext ctx) {
1969 auto op = dyn_cast<mlir::mhlo::AddOp>(inst);
1970 if (op && op.getResult()
1971 .getType()
1972 .cast<mlir::TensorType>()
1973 .getElementType()
1974 .isSignlessInteger(1)) {
1975 auto& value_map = *ctx.values;
1976 auto result = op.getResult();
1977 xla::XlaOp xla_arg_0;
1978 if (failed(GetXlaOp(op.lhs(), value_map, &xla_arg_0, op)))
1979 return mlir::failure();
1980 xla::XlaOp xla_arg_1;
1981 if (failed(GetXlaOp(op.rhs(), value_map, &xla_arg_1, op)))
1982 return mlir::failure();
1983 auto xla_result = xla::Xor(Unwrap(xla_arg_0), Unwrap(xla_arg_1));
1984 value_map[result] = xla_result;
1985 return mlir::success();
1986 }
1987
1988 return ExportXlaOperator(inst, ctx);
1989 }
1990
1991 LogicalResult ConvertToHloModule::Lower(
1992 mlir::Operation* inst, bool is_entry_function,
1993 llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
1994 xla::XlaBuilder* builder,
1995 ConvertToHloModule::ValueLoweringMap* value_lowering,
1996 xla::XlaOp* return_value) {
1997 // Explicitly fail for ops that are not supported for export.
1998 if (inst->getDialect() !=
1999 inst->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>() &&
2000 !mlir::isa<mlir::func::ConstantOp, mlir::arith::ConstantOp,
2001 mlir::func::CallOp, mlir::tensor::CastOp,
2002 mlir::func::ReturnOp>(inst)) {
2003 inst->emitOpError("unsupported op for export to XLA");
2004 return failure();
2005 }
2006
2007 *return_value = xla::XlaOp();
2008
2009 // See MlirToHloConversionOptions for more about layouts.
2010 auto propagate_layouts = [this](mlir::Operation* inst,
2011 xla::XlaOp xla_op) -> mlir::LogicalResult {
2012 if (options_.propagate_layouts) {
2013 auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
2014 ->mutable_shape();
2015 // TODO(kramm): merge this with ConvertLayout.
2016 *shape = ExtractXlaShape(inst).ToProto();
2017 }
2018
2019 return success();
2020 };
2021
2022 if (succeeded(
2023 ExportXlaOperatorWrapped(inst, {value_lowering, this, builder}))) {
2024 if (inst->getNumResults() == 1) {
2025 auto iter = value_lowering->find(inst->getResult(0));
2026 if (iter == value_lowering->end()) {
2027 inst->emitOpError(
2028 "inst has a result, but it's not found in value_lowering");
2029 return failure();
2030 }
2031 if (failed(propagate_layouts(inst, iter->second))) {
2032 return failure();
2033 }
2034 }
2035 // For infeed ops stemming back to InfeedDequeueTuple, respect the
2036 // layout attribute, and create the corresponding layout in hlo.
2037 if (isa<mhlo::InfeedOp>(inst)) {
2038 mlir::ArrayAttr layout =
2039 inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);
2040
2041 if (layout) {
2042 // We propagate layout to the following three ops:
2043 // L1: For each data-result of mhlo.InfeedOp, we find the exported
2044 // xla::kGetTupleElement and propagate the layout.
2045 //
2046 // L2: For the token-result of mhlo.InfeedOp (result at last index),
2047 // we extract the xla::kInfeed op using the corresponding
2048 // xla::kGetTupleElement and propagate the layout to it.
2049 //
2050 // L3: In case there are non-zero data-results, there exists an
2051 // additional xla::kGetTupleElement accessing a tuple of the
2052 // data-results. We need to propagate the layout to that
2053 // xla::kGetTupleElement as well.
2054 auto num_results = inst->getNumResults();
2055 bool propagate_layout_to_data_tuple = true;
2056 for (unsigned i = 0; i < num_results; i++) {
2057 auto iter = value_lowering->find(inst->getResult(i));
2058 if (iter == value_lowering->end()) {
2059 inst->emitOpError() << "inst's result value at index " << i
2060 << " has no match in value_lowering";
2061 return failure();
2062 }
2063 auto xla_gte_op = iter->second;
2064 xla::HloInstructionProto* get_tuple_element_proto =
2065 xla::internal::XlaBuilderFriend::GetInstruction(xla_gte_op);
2066
2067 assert(xla::StringToHloOpcode(get_tuple_element_proto->opcode())
2068 .ValueOrDie() == xla::HloOpcode::kGetTupleElement &&
2069 "The token-result of mhlo.InfeedOp should be mapped to a "
2070 "xla::HloOpcode::kGetTupleElement");
2071
2072 if (i == num_results - 1) {
2073 // L2
2074 xla::HloInstructionProto* xla_infeed_op_proto =
2075 xla::internal::XlaBuilderFriend::GetInstructionByHandle(
2076 xla_gte_op.builder(),
2077 get_tuple_element_proto->operand_ids(0));
2078
2079 assert(xla::StringToHloOpcode(xla_infeed_op_proto->opcode())
2080 .ValueOrDie() == xla::HloOpcode::kInfeed &&
2081 "Expected xla::HloOpcode::kInfeed op");
2082
2083 auto* shape = xla_infeed_op_proto->mutable_shape();
2084 if (failed(ConvertInfeedtLayout(inst, layout, shape)))
2085 return failure();
2086
2087 } else {
2088 // L1
2089 auto* shape = get_tuple_element_proto->mutable_shape();
2090 if (failed(ConvertInfeedtLayout(inst, layout, shape, i)))
2091 return failure();
2092
2093 // L3
2094 if (propagate_layout_to_data_tuple) {
2095 xla::HloInstructionProto* data_tuple_proto =
2096 xla::internal::XlaBuilderFriend::GetInstructionByHandle(
2097 xla_gte_op.builder(),
2098 get_tuple_element_proto->operand_ids(0));
2099 auto* data_tuple_shape = data_tuple_proto->mutable_shape();
2100
2101 assert(xla::StringToHloOpcode(data_tuple_proto->opcode())
2102 .ValueOrDie() ==
2103 xla::HloOpcode::kGetTupleElement &&
2104 "Expected a xla:tupleOp for all the data results.");
2105 if (failed(ConvertInfeedtLayout(inst, layout, data_tuple_shape)))
2106 return failure();
2107 }
2108 propagate_layout_to_data_tuple = false;
2109 }
2110 }
2111 }
2112 }
2113 return success();
2114 }
2115
2116 auto& value_map = *value_lowering;
2117 ElementsAttr const_attr;
2118
2119 if (auto call_op = dyn_cast<mlir::func::CallOp>(inst)) {
2120 return LowerFunctionCall(call_op, builder, &value_map);
2121 }
2122
2123 if (auto op = dyn_cast<mlir::tensor::CastOp>(inst)) {
2124 Value operand = op.getOperand();
2125 auto ty = operand.getType().dyn_cast<ShapedType>();
2126 // If this was a cast from a static or bounded tensor, then it is a no-op
2127 // for export to HLO and we can use the operand directly.
2128 if (!ty || !IsBoundedOrStatic(ty)) {
2129 inst->emitOpError()
2130 << "requires static or bounded operand for HLO translation";
2131 return failure();
2132 }
2133
2134 xla::XlaOp xla_operand;
2135 if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
2136 return failure();
2137 value_map[op.getResult()] = xla_operand;
2138 if (failed(propagate_layouts(inst, xla_operand))) {
2139 return failure();
2140 }
2141 return success();
2142 }
2143
2144 if (matchPattern(inst, m_Constant(&const_attr))) {
2145 if (!inst->getResult(0).getType().isa<ShapedType>()) {
2146 return inst->emitError(
2147 "expected shaped type during constant mhlo -> hlo translation");
2148 }
2149
2150 auto literal_or =
2151 CreateArrayLiteralFromAttr(const_attr, ExtractXlaShape(inst).layout());
2152 if (!literal_or.ok())
2153 return inst->emitError(literal_or.status().ToString());
2154 auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
2155 value_map[inst->getResult(0)] = constant;
2156
2157 return success();
2158 }
2159
2160 if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) {
2161 // Construct the return value for the function. If there is a single value
2162 // returned, then return it directly, else create a tuple and return.
2163 unsigned num_return_values = inst->getNumOperands();
2164 const bool has_ret_shardings =
2165 !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
2166 if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
2167 std::vector<xla::XlaOp> returns(num_return_values);
2168 for (OpOperand& ret : inst->getOpOperands()) {
2169 unsigned index = ret.getOperandNumber();
2170 xla::XlaOp operand;
2171 if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
2172 return failure();
2173
2174 returns[index] = operand;
2175 if (!is_entry_function || !has_ret_shardings) continue;
2176
2177 xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
2178 StatusOr<xla::XlaOp> reshape =
2179 ReshapeWithCorrectRepresentationAndSharding(
2180 builder, returns[index], return_shape,
2181 options_.layout_preference_fn, options_.shape_representation_fn,
2182 ret_shardings[index], /*fast_mem=*/false);
2183 if (!reshape.ok())
2184 return inst->emitError() << reshape.status().error_message();
2185
2186 returns[index] = reshape.ValueOrDie();
2187 }
2188
2189 if (has_ret_shardings) {
2190 xla::OpSharding sharding;
2191 sharding.set_type(xla::OpSharding::TUPLE);
2192 for (auto& ret_sharding : ret_shardings)
2193 *sharding.add_tuple_shardings() = *ret_sharding;
2194
2195 builder->SetSharding(sharding);
2196 }
2197
2198 *return_value = xla::Tuple(builder, returns);
2199 builder->ClearSharding();
2200 } else if (num_return_values == 1) {
2201 xla::XlaOp operand;
2202 if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
2203 return failure();
2204
2205 if (has_ret_shardings) {
2206 auto tuple = Tuple(builder, {operand});
2207 builder->SetSharding(*ret_shardings[0]);
2208 *return_value = GetTupleElement(tuple, 0);
2209 builder->ClearSharding();
2210 } else {
2211 *return_value = operand;
2212 }
2213 }
2214
2215 return success();
2216 }
2217
2218 inst->emitOpError() << "can't be translated to XLA HLO";
2219 return failure();
2220 }
2221
2222 LogicalResult ConvertToHloModule::LowerFunctionCall(
2223 mlir::func::CallOp call_op, xla::XlaBuilder* builder,
2224 ConvertToHloModule::ValueLoweringMap* value_lowering) {
2225 auto& value_map = *value_lowering;
2226 mlir::func::FuncOp callee =
2227 module_.lookupSymbol<mlir::func::FuncOp>(call_op.getCallee());
2228 if (failed(RunOnFunction(callee))) return failure();
2229 std::vector<xla::XlaOp> operands;
2230 for (auto operand : call_op.getOperands()) {
2231 xla::XlaOp xla_operand;
2232 if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op)))
2233 return failure();
2234 operands.push_back(xla_operand);
2235 }
2236 // Each call to xla::Call would insert a copy of the computation to
2237 // the HLO. Thus each callsite would have a unique callee in the
2238 // exported HLO. HLO syntactically does not require all calls to have unique
2239 // callees, but the call graph is eventually "flattened" before lowering to
2240 // make that true. This is done before lowering because buffer assignment
2241 // needs this invariant.
2242 xla::XlaOp call_result =
2243 xla::Call(builder, lowered_computation_[callee], operands);
2244 // Use GetTupleElement for multiple outputs
2245 unsigned num_results = call_op.getNumResults();
2246 if (num_results > 1) {
2247 for (unsigned i = 0; i != num_results; ++i) {
2248 value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i);
2249 }
2250 } else if (num_results == 1) {
2251 value_map[call_op.getResult(0)] = call_result;
2252 }
2253 return success();
2254 }
2255
2256 LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
2257 if (lowered_computation_.count(f)) return success();
2258 if (!llvm::hasSingleElement(f)) {
2259 return f.emitError("only single block Function supported");
2260 }
2261
2262 // Create a sub-builder if this is not the main function.
2263 std::unique_ptr<xla::XlaBuilder> builder_up;
2264 bool entry_function = f.getName() == "main";
2265 if (!entry_function)
2266 builder_up = module_builder_.CreateSubBuilder(f.getName().str());
2267 auto& builder = entry_function ? module_builder_ : *builder_up;
2268
2269 xla::XlaComputation computation;
2270 std::vector<bool> entry_args_same_across_replicas;
2271 llvm::SmallVector<std::optional<xla::OpSharding>, 4> arg_shardings;
2272 llvm::SmallVector<std::optional<xla::OpSharding>, 4> ret_shardings;
2273 if (entry_function) {
2274 bool any_arg_replicated = false;
2275 entry_args_same_across_replicas.reserve(f.getNumArguments());
2276 for (int64_t i = 0; i < f.getNumArguments(); ++i) {
2277 auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kReplicationAttr);
2278 entry_args_same_across_replicas.push_back(attr != nullptr);
2279 any_arg_replicated |= entry_args_same_across_replicas.back();
2280 // Pass the alias info to the builder so that it will build the alias info
2281 // into the resulting HloModule.
2282 auto aliasing_output =
2283 f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
2284 if (!aliasing_output) continue;
2285 xla::ShapeIndex output_index;
2286 if ((return_tuple_ && entry_function) || f.getNumResults() != 1) {
2287 output_index = {aliasing_output.getInt()};
2288 } else {
2289 if (aliasing_output.getInt() != 0) {
2290 return f.emitError(
2291 "Aliasing output must be 0 if only one output exists");
2292 }
2293 output_index = {};
2294 }
2295 if (use_tuple_args_) {
2296 builder.SetUpAlias(output_index, /*param_number=*/0,
2297 /*param_index=*/{i});
2298 } else {
2299 builder.SetUpAlias(output_index, /*param_number=*/i,
2300 /*param_index=*/{});
2301 }
2302 }
2303 // Do not populate this field when nothing is replicated, since empty field
2304 // means no replication. This avoids the need for unrelated tests to handle
2305 // this field.
2306 if (!any_arg_replicated) entry_args_same_across_replicas.clear();
2307
2308 ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
2309 }
2310 if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function,
2311 false, entry_args_same_across_replicas,
2312 arg_shardings, ret_shardings,
2313 &computation))) {
2314 return failure();
2315 }
2316 lowered_computation_[f] = std::move(computation);
2317 return success();
2318 }
2319
2320 LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
2321 Block* block, const std::vector<bool>& entry_args_same_across_replicas,
2322 llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
2323 std::vector<bool>* leaf_replication) {
2324 arg_shapes->reserve(block->getNumArguments());
2325 leaf_replication->reserve(block->getNumArguments());
2326 for (BlockArgument& arg : block->getArguments()) {
2327 arg_shapes->push_back(xla::TypeToShape(arg.getType()));
2328 xla::Shape& arg_shape = arg_shapes->back();
2329 auto layout_preference_status =
2330 options_.layout_preference_fn ? options_.layout_preference_fn(arg_shape)
2331 : XlaLayoutPreference::kNoPreference;
2332 if (!layout_preference_status.ok())
2333 return block->getParentOp()->emitError()
2334 << layout_preference_status.status().error_message();
2335
2336 auto arg_shape_status = options_.shape_representation_fn
2337 ? options_.shape_representation_fn(
2338 arg_shape, /*use_fast_memory=*/false,
2339 layout_preference_status.ValueOrDie())
2340 : arg_shape;
2341 if (!arg_shape_status.ok())
2342 return block->getParentOp()->emitError()
2343 << arg_shape_status.status().error_message();
2344
2345 arg_shape = std::move(arg_shape_status.ValueOrDie());
2346
2347 if (entry_args_same_across_replicas.empty()) continue;
2348 for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
2349 leaf_replication->push_back(
2350 entry_args_same_across_replicas[arg.getArgNumber()]);
2351 }
2352
2353 return success();
2354 }
2355
2356 LogicalResult ConvertToHloModule::SetEntryTupleShardings(
2357 Block* block, xla::XlaBuilder* builder,
2358 llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
2359 llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
2360 if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
2361 xla::OpSharding sharding;
2362 sharding.set_type(xla::OpSharding::TUPLE);
2363 for (const auto& arg_sharding : llvm::enumerate(arg_shardings)) {
2364 auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value());
2365 if (!hlo_sharding.ok())
2366 return block->getParentOp()->emitError()
2367 << hlo_sharding.status().error_message();
2368
2369 auto status = RewriteLayoutWithShardedShape(
2370 hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
2371 options_.layout_preference_fn, options_.shape_representation_fn,
2372 &(*arg_shapes)[arg_sharding.index()]);
2373 if (!status.ok())
2374 return block->getParentOp()->emitError() << status.error_message();
2375
2376 *sharding.add_tuple_shardings() = *arg_sharding.value();
2377 }
2378
2379 builder->SetSharding(sharding);
2380 }
2381
2382 return success();
2383 }
2384
2385 LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
2386 Block* block, xla::XlaBuilder* builder, bool is_entry_function,
2387 bool ensure_single_arg,
2388 const std::vector<bool>& entry_args_same_across_replicas,
2389 llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
2390 llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
2391 xla::XlaComputation* result,
2392 llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands) {
2393 // Mapping from the Value to lowered XlaOp.
2394 ValueLoweringMap lowering;
2395
2396 // If using tuples as input, then there is only one input parameter that is a
2397 // tuple.
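// For example (hypothetical signature), a main function taking
// (tensor<2xf32>, tensor<3xi32>) becomes a single parameter of shape
// (f32[2], s32[3]), and each block argument is then bound to a
// GetTupleElement of that parameter below.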
2398 if (is_entry_function && use_tuple_args_) {
2399 llvm::SmallVector<xla::Shape, 4> arg_shapes;
2400 std::vector<bool> leaf_replication;
2401 if (failed(SetEntryTupleShapesAndLeafReplication(
2402 block, entry_args_same_across_replicas, &arg_shapes,
2403 &leaf_replication)))
2404 return failure();
2405
2406 if (failed(
2407 SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
2408 return failure();
2409
2410 xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
2411 auto tuple =
2412 xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
2413 builder->ClearSharding();
2414
2415 bool set_tuple_element_sharding =
2416 !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings);
2417 for (BlockArgument& arg : block->getArguments()) {
2418 if (set_tuple_element_sharding)
2419 builder->SetSharding(*arg_shardings[arg.getArgNumber()]);
2420 lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
2421 }
2422 builder->ClearSharding();
2423 } else {
2424 if (ensure_single_arg) {
2425 // Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp.
2426 llvm::SmallVector<xla::Shape, 4> arg_shapes;
2427
2428 auto args_size = block->getNumArguments();
2429 if (implicit_operands) args_size = implicit_operands->size();
2430
2431 arg_shapes.reserve(args_size);
2432 if (implicit_operands) {
2433 for (auto implicit_operand : *implicit_operands)
2434 arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType()));
2435 } else {
2436 for (BlockArgument& arg : block->getArguments())
2437 arg_shapes.push_back(xla::TypeToShape(arg.getType()));
2438 }
2439
2440 if (args_size > 1) {
2441 auto tuple = xla::Parameter(builder, 0,
2442 xla::ShapeUtil::MakeTupleShape(arg_shapes),
2443 "arg_tuple");
2444
2445 if (implicit_operands) {
2446 int arg_index = 0;
2447 for (auto implicit_operand : *implicit_operands)
2448 lowering[implicit_operand] =
2449 xla::GetTupleElement(tuple, arg_index++);
2450 } else {
2451 for (BlockArgument& arg : block->getArguments())
2452 lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
2453 }
2454 } else if (args_size == 1) {
2455 if (implicit_operands) {
2456 lowering[(*implicit_operands)[0]] =
2457 xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
2458 } else {
2459 lowering[block->getArgument(0)] =
2460 xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
2461 }
2462 } else {
2463 // Applicable only for IfOp or CaseOp. No implicit operands implies no
2464 // xla parameters. In this case, we create an empty tuple as the
2465 // block-parameter.
2466 xla::Parameter(builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes),
2467 "arg_empty_tuple");
2468 }
2469 } else {
2470 for (BlockArgument& arg : block->getArguments()) {
2471 auto num = arg.getArgNumber();
2472 xla::Shape shape = xla::TypeToShape(arg.getType());
2473 if (!arg_shardings.empty() && arg_shardings[num]) {
2474 builder->SetSharding(*arg_shardings[num]);
2475 }
2476 if (entry_args_same_across_replicas.empty()) {
2477 lowering[arg] =
2478 xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
2479 } else {
2480 lowering[arg] = xla::Parameter(
2481 builder, num, shape, absl::StrCat("Arg_", num),
2482 std::vector<bool>(xla::ShapeUtil::GetLeafCount(shape),
2483 entry_args_same_across_replicas[num]));
2484 }
2485 builder->ClearSharding();
2486 }
2487 }
2488 }
2489
2490 xla::XlaOp return_value;
2491 for (auto& inst : *block)
2492 if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
2493 &lowering, &return_value)))
2494 return failure();
2495
2496 // Build the XlaComputation and check for failures.
2497 auto computation_or =
2498 return_value.valid() ? builder->Build(return_value) : builder->Build();
2499 if (!computation_or.ok()) {
2500 block->back().emitError(
2501 llvm::Twine(computation_or.status().error_message()));
2502 return failure();
2503 }
2504 *result = std::move(computation_or.ValueOrDie());
2505 return success();
2506 }
2507
2508 LogicalResult ConvertToHloModule::LowerRegionAsComputation(
2509 mlir::Region* region, xla::XlaComputation* func,
2510 llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands,
2511 bool ensure_single_arg) {
2512 std::unique_ptr<xla::XlaBuilder> builder =
2513 module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
2514 return LowerBasicBlockAsFunction(&region->front(), builder.get(),
2515 /*is_entry_function=*/false,
2516 /*ensure_single_arg*/ ensure_single_arg,
2517 /*entry_args_same_across_replicas=*/{},
2518 /*arg_shardings=*/{}, /*ret_shardings=*/{},
2519 func, implicit_operands);
2520 }
2521
2522 void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
2523 int arg_index, int32_t shape_index,
2524 int32_t padding_arg_index,
2525 bool use_tuple_args) {
2526 auto* entry = binding->add_entries();
2527 entry->set_target_param_dim_num(shape_index);
2528 if (use_tuple_args) {
2529 entry->set_target_param_num(0);
2530 entry->add_target_param_index(arg_index);
2531 entry->set_dynamic_param_num(0);
2532 entry->add_dynamic_param_index(padding_arg_index);
2533 } else {
2534 entry->set_target_param_num(arg_index);
2535 entry->set_dynamic_param_num(padding_arg_index);
2536 }
2537 }
2538
2539 // Runs the PrepareForExport pass on the ModuleOp.
2540 Status PrepareForExport(mlir::ModuleOp module) {
2541 // Prepare for export to XLA HLO.
2542 mlir::PassManager pm(module.getContext());
2543 pm.addNestedPass<mlir::func::FuncOp>(mhlo::CreatePrepareForExport());
2544 if (failed(pm.run(module)))
2545 return tensorflow::errors::Internal("Unable to optimize for XLA export");
2546 return ::tensorflow::OkStatus();
2547 }
2548
2549 } // namespace
2550
2551 Status ConvertRegionToComputation(mlir::Region* region,
2552 xla::XlaComputation* func,
2553 MlirToHloConversionOptions options) {
2554 mlir::ModuleOp module;
2555 xla::XlaBuilder module_builder("main");
2556 ConvertToHloModule converter(module, module_builder, true, true, options);
2557 if (failed(converter.LowerRegionAsComputation(region, func)))
2558 return tensorflow::errors::Internal(
2559 "failed to convert region to computation");
2560 return ::tensorflow::OkStatus();
2561 }
2562
2563 Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
2564 bool use_tuple_args, bool return_tuple,
2565 MlirToHloConversionOptions options) {
2566 TF_RETURN_IF_ERROR(PrepareForExport(module));
2567 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
2568 xla::XlaBuilder module_builder("main");
2569 ConvertToHloModule converter(module, module_builder, use_tuple_args,
2570 return_tuple, options);
2571 if (failed(converter.Run())) return diag_handler.ConsumeStatus();
2572 auto hlo_module = converter.ConsumeMainProto();
2573 StringRef module_name = module.getName() ? *module.getName() : "main";
2574 hlo_module.set_name(module_name.str());
2575 hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
2576 return ::tensorflow::OkStatus();
2577 }
2578
2579 Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
2580 llvm::ArrayRef<xla::XlaOp> xla_params,
2581 std::vector<xla::XlaOp>& returns,
2582 MlirToHloConversionOptions options) {
2583 auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
2584 TF_RETURN_IF_ERROR(PrepareForExport(module));
2585 ConvertToHloModule converter(module, builder,
2586 /*use_tuple_args=*/false, /*return_tuple=*/false,
2587 options);
2588
2589 ConvertToHloModule::ValueLoweringMap lowering;
2590 // xla_params should only include non-constant parameters the block arguments
2591 // correspond to.
2592 if (xla_params.size() != block.getArguments().size())
2593 return tensorflow::errors::Internal("xla_params size (", xla_params.size(),
2594 ") != block arguments size (",
2595 block.getArguments().size(), ")");
2596 for (BlockArgument& arg : block.getArguments()) {
2597 auto num = arg.getArgNumber();
2598 lowering[arg] = xla_params[num];
2599 }
2600
2601 mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
2602 for (auto& inst : block) {
2603 if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) {
2604 returns.resize(inst.getNumOperands());
2605 for (OpOperand& ret : inst.getOpOperands()) {
2606 unsigned index = ret.getOperandNumber();
2607 xla::XlaOp operand;
2608 if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
2609 return diag_handler.ConsumeStatus();
2610 returns[index] = operand;
2611 }
2612 } else {
2613 xla::XlaOp return_value;
2614 if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
2615 /*ret_shardings=*/{}, &builder, &lowering,
2616 &return_value)))
2617 return diag_handler.ConsumeStatus();
2618 }
2619 }
2620
2621 return ::tensorflow::OkStatus();
2622 }
2623
2624 } // namespace mlir
2625