/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"

#include <memory>
#include <optional>
#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using ::int64_t;
using ::stream_executor::port::StatusOr;
using ::tensorflow::int16;
using ::tensorflow::int32;
using ::tensorflow::int8;
using ::tensorflow::uint16;
using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;

constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";

// Array attribute. Same shape as infeed result, but contains a
// minor_to_major array for every tensor.
constexpr char kLayoutAttr[] = "layout";
constexpr char kDefaultLayoutAttrName[] = "xla_shape";

// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
template <typename T>
T Unwrap(T t) {
  return t;
}

template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
  return t.get();
}
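
// Example (illustrative): Convert_precision_config below returns a
// std::unique_ptr<xla::PrecisionConfig>, while xla::Dot takes a raw
// const xla::PrecisionConfig*; Unwrap bridges the two without leaking:
//
//   value_map[op] = xla::Dot(
//       lhs, rhs, Unwrap(Convert_precision_config(op.precision_config())));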

static mlir::LogicalResult GetXlaOp(
    mlir::Value val, const llvm::DenseMap<mlir::Value, xla::XlaOp>& val_map,
    xla::XlaOp* result, mlir::Operation* op) {
  auto iter = val_map.find(val);
  if (iter == val_map.end()) {
    return op->emitOpError(
        "requires all operands to be defined in the parent region for export");
  }
  *result = iter->second;
  return mlir::success();
}

bool IsBoundedOrStatic(mlir::Type ty) {
  auto ranked_ty = ty.dyn_cast_or_null<mlir::RankedTensorType>();
  if (!ranked_ty) return false;

  if (ranked_ty.hasStaticShape()) return true;

  auto encoding = ranked_ty.getEncoding()
                      .dyn_cast_or_null<mlir::mhlo::TypeExtensionsAttr>();
  if (!encoding || encoding.getBounds().empty()) return false;

  int64_t rank = ranked_ty.getRank();
  for (int64_t dim = 0; dim < rank; ++dim) {
    if (ranked_ty.isDynamicDim(dim) &&
        encoding.getBounds()[dim] == mlir::ShapedType::kDynamicSize)
      return false;
  }
  return true;
}
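
// For instance (illustrative): a type along the lines of
//   tensor<3x?xf32, #mhlo.type_extensions<bounds = [-1, 8]>>
// is bounded (the dynamic dimension has bound 8), while a plain
// tensor<3x?xf32> with no encoding is neither bounded nor static.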

// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }

static uint32_t Convertuint32_t(uint32_t i) { return i; }
static uint64_t Convertuint64_t(uint64_t i) { return i; }

// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
  const auto& semantics = value.getSemantics();
  bool losesInfo = false;
  if (&semantics != &llvm::APFloat::IEEEdouble())
    value.convert(llvm::APFloat::IEEEdouble(),
                  llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return value.convertToDouble();
}

static inline bool Convertbool(bool value) { return value; }

static absl::string_view ConvertStringRef(mlir::StringRef value) {
  return {value.data(), value.size()};
}

static std::vector<int64_t> ConvertDenseIntAttr(
    mlir::DenseIntElementsAttr attr) {
  auto values = attr.getValues<int64_t>();
  return {values.begin(), values.end()};
}

static std::vector<int64_t> ConvertDenseIntAttr(
    llvm::Optional<mlir::DenseIntElementsAttr> attr) {
  if (!attr) return {};
  return ConvertDenseIntAttr(*attr);
}

// Converts the broadcast_dimensions attribute into a vector of dimension
// numbers (empty if the attribute is absent).
static std::vector<int64_t> Convert_broadcast_dimensions(
    llvm::Optional<mlir::DenseIntElementsAttr> broadcast_dimensions) {
  if (!broadcast_dimensions.has_value()) return {};

  return ConvertDenseIntAttr(*broadcast_dimensions);
}

// Converts the mhlo FftType enum to the equivalent xla FftType enum.
static xla::FftType Convert_fft_type(mlir::mhlo::FftType fft_type) {
  xla::FftType fft_type_enum;
  // An illegal fft_type would be caught by the verifier, so the
  // 'FftType_Parse' call below should never return false.
  if (!FftType_Parse(std::string(mlir::mhlo::stringifyFftType(fft_type)),
                     &fft_type_enum))
    return xla::FftType::FFT;
  return fft_type_enum;
}

static std::vector<std::pair<int64_t, int64_t>> Convert_padding(
    llvm::Optional<mlir::DenseIntElementsAttr> padding) {
  return xla::ConvertNx2Attribute(padding).ValueOrDie();
}
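
// Example (illustrative): a padding attribute of
// dense<[[0, 1], [2, 3]]> : tensor<2x2xi64> converts to
// {{0, 1}, {2, 3}}, i.e. a (low, high) pair per dimension.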

static std::optional<bool> Convert_use_global_device_ids(
    llvm::Optional<bool> use_global_device_ids) {
  if (!use_global_device_ids) return {};
  return *use_global_device_ids;
}

static std::vector<std::pair<int64_t, int64_t>> Convert_source_target_pairs(
    llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) {
  return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie();
}

static std::vector<xla::ReplicaGroup> Convert_replica_groups(
    mlir::DenseIntElementsAttr groups) {
  return xla::ConvertReplicaGroups(groups).ValueOrDie();
}

// Converts types and corresponding layouts into xla shapes with layouts.
static std::vector<xla::Shape> ConvertTypesToShapesWithLayout(
    mlir::TypeRange value_types, mlir::ArrayAttr layouts) {
  std::vector<xla::Shape> shapes_with_layout;
  for (auto type_and_layout : llvm::zip(value_types, layouts)) {
    mlir::Type type = std::get<0>(type_and_layout);
    mlir::Attribute layout = std::get<1>(type_and_layout);
    assert(!type.isa<mlir::TupleType>() &&
           "Exporting layout for tuples is not implemented yet");
    shapes_with_layout.emplace_back(xla::TypeToShape(type));
    auto& shape = shapes_with_layout.back();
    shape.mutable_layout()->clear_minor_to_major();
    for (auto l : layout.cast<mlir::DenseIntElementsAttr>()) {
      shape.mutable_layout()->mutable_minor_to_major()->push_back(
          l.getSExtValue());
    }
  }
  return shapes_with_layout;
}
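
// Example (illustrative): tensor<2x3xf32> paired with layout dense<[0, 1]>
// yields the XLA shape f32[2,3]{0,1}, i.e. dimension 0 is most minor.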

// CustomCallOp result can be of tuple type to pack multiple results into one
// value. If the custom call result is a tuple, then result layouts represent
// the layout of each element of the tuple. Nested tuples are currently not
// supported for export.
static xla::Shape GetCustomCallResultShapeWithLayout(mlir::Type type,
                                                     mlir::ArrayAttr layouts) {
  auto tuple_type = type.dyn_cast<mlir::TupleType>();
  if (!tuple_type) return ConvertTypesToShapesWithLayout({type}, layouts)[0];

  std::vector<xla::Shape> shapes_with_layouts =
      ConvertTypesToShapesWithLayout(tuple_type.getTypes(), layouts);
  return xla::ShapeUtil::MakeTupleShape(shapes_with_layouts);
}

// Converts the mhlo Transpose enum to the xla Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
    mlir::mhlo::Transpose transpose) {
  return xla::ConvertTranspose(mlir::mhlo::stringifyTranspose(transpose))
      .ValueOrDie();
}

static xla::Layout ExtractLayout(
    mlir::Operation* op, int rank,
    llvm::StringRef attr_name = kDefaultLayoutAttrName) {
  if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(attr_name)) {
    llvm::SmallVector<int64_t, 4> minor_to_major;
    DCHECK_EQ(rank, attr.size());
    minor_to_major.reserve(attr.size());
    for (const llvm::APInt& i : attr) {
      minor_to_major.push_back(i.getZExtValue());
    }
    return xla::LayoutUtil::MakeLayout(minor_to_major);
  }
  return xla::LayoutUtil::MakeDescendingLayout(rank);
}
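
// For instance (illustrative): with a rank of 3 and no layout attribute on
// the op, ExtractLayout falls back to the default descending layout {2,1,0};
// a dense<[0, 1, 2]> attribute would instead produce {0,1,2}.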

static xla::Shape ExtractXlaShape(mlir::Operation* op) {
  if (auto attr = op->getAttrOfType<mlir::StringAttr>(kDefaultLayoutAttrName)) {
    return *xla::ParseShape(
        absl::string_view(attr.getValue().data(), attr.getValue().size()));
  } else {
    std::vector<xla::Shape> subshapes;
    for (mlir::Value result : op->getResults()) {
      subshapes.push_back(xla::TypeToShape(result.getType()));
    }
    if (subshapes.size() > 1) {
      return xla::ShapeUtil::MakeTupleShape(subshapes);
    }
    return subshapes[0];
  }
}
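
// Example (illustrative): an op with {xla_shape = "f32[2,3]{1,0}"} parses the
// serialized shape string directly; without the attribute, the shape is
// rebuilt from the op's result types (tupled when there are multiple results).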

#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute)                \
  static std::vector<int64_t> Convert_##attribute(            \
      llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
    return ConvertDenseIntAttr(attribute);                    \
  }

I64_ELEMENTS_ATTR_TO_VECTOR(broadcast_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(permutation);
I64_ELEMENTS_ATTR_TO_VECTOR(start_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(limit_indices);
I64_ELEMENTS_ATTR_TO_VECTOR(strides);
I64_ELEMENTS_ATTR_TO_VECTOR(slice_sizes);
I64_ELEMENTS_ATTR_TO_VECTOR(fft_length);
I64_ELEMENTS_ATTR_TO_VECTOR(dimensions);
I64_ELEMENTS_ATTR_TO_VECTOR(window_strides);
I64_ELEMENTS_ATTR_TO_VECTOR(lhs_dilation);
I64_ELEMENTS_ATTR_TO_VECTOR(rhs_dilation);

#undef I64_ELEMENTS_ATTR_TO_VECTOR
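
// For instance, I64_ELEMENTS_ATTR_TO_VECTOR(permutation) expands to:
//   static std::vector<int64_t> Convert_permutation(
//       llvm::Optional<mlir::DenseIntElementsAttr> permutation) {
//     return ConvertDenseIntAttr(permutation);
//   }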

#define BOOL_ELEMENTS_ATTR_TO_VECTOR(attribute)            \
  static std::vector<bool> Convert_##attribute(            \
      llvm::Optional<mlir::DenseElementsAttr> attribute) { \
    if (!attribute) return {};                             \
    auto values = attribute->getValues<bool>();            \
    return {values.begin(), values.end()};                 \
  }

BOOL_ELEMENTS_ATTR_TO_VECTOR(window_reversal);

#undef BOOL_ELEMENTS_ATTR_TO_VECTOR

static std::vector<int64_t> Convert_ArrayRef(llvm::ArrayRef<int64_t> values) {
  return {values.begin(), values.end()};
}

// Converts the precision config array of strings attribute into the
// corresponding XLA proto. All the strings are assumed to be valid names of the
// Precision enum. This should have been checked in the op verify method.
static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
    llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
  if (!optional_precision_config_attr.has_value()) return nullptr;

  auto precision_config = std::make_unique<xla::PrecisionConfig>();
  for (auto attr : optional_precision_config_attr.getValue()) {
    xla::PrecisionConfig::Precision p;
    auto operand_precision =
        mlir::mhlo::stringifyPrecision(
            attr.cast<mlir::mhlo::PrecisionAttr>().getValue())
            .str();
    // TODO(jpienaar): Update this to ensure this is captured by verify.
    if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
      precision_config->add_operand_precision(p);
    } else {
      auto* context = attr.getContext();
      mlir::emitError(mlir::UnknownLoc::get(context))
          << "unexpected operand precision " << operand_precision;
      return nullptr;
    }
  }

  return precision_config;
}
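
// Example (illustrative): a two-element precision config naming HIGH and
// HIGHEST becomes a PrecisionConfig proto with
// operand_precision: [HIGH, HIGHEST]; an unparseable name yields nullptr.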

static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
    mlir::mhlo::DotDimensionNumbersAttr dot_dimension_numbers_attr) {
  xla::DotDimensionNumbers dot_dimension_numbers;

  auto rhs_contracting_dimensions =
      dot_dimension_numbers_attr.getRhsContractingDimensions();
  auto lhs_contracting_dimensions =
      dot_dimension_numbers_attr.getLhsContractingDimensions();
  auto rhs_batch_dimensions =
      dot_dimension_numbers_attr.getRhsBatchingDimensions();
  auto lhs_batch_dimensions =
      dot_dimension_numbers_attr.getLhsBatchingDimensions();

  for (const auto& val : rhs_contracting_dimensions) {
    dot_dimension_numbers.add_rhs_contracting_dimensions(val);
  }
  for (const auto& val : lhs_contracting_dimensions) {
    dot_dimension_numbers.add_lhs_contracting_dimensions(val);
  }

  for (const auto& val : rhs_batch_dimensions) {
    dot_dimension_numbers.add_rhs_batch_dimensions(val);
  }

  for (const auto& val : lhs_batch_dimensions) {
    dot_dimension_numbers.add_lhs_batch_dimensions(val);
  }

  return dot_dimension_numbers;
}

static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
    mlir::mhlo::ConvDimensionNumbersAttr input) {
  return xla::ConvertConvDimensionNumbers(input);
}

xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandleAttr attr) {
  xla::ChannelHandle channel_handle;
  channel_handle.set_handle(attr.getHandle());
  channel_handle.set_type(
      static_cast<xla::ChannelHandle::ChannelType>(attr.getType()));
  return channel_handle;
}

std::optional<xla::ChannelHandle> Convert_channel_handle(
    llvm::Optional<mlir::mhlo::ChannelHandleAttr> attr) {
  if (!attr.has_value()) return std::nullopt;
  return Convert_channel_handle(attr.getValue());
}

// Converts the comparison_direction string attribute into the XLA enum. The
// string is assumed to correspond to exactly one of the allowed strings
// representing the enum. This should have been checked in the op verify method.
static xla::ComparisonDirection Convert_comparison_direction(
    llvm::StringRef comparison_direction_string) {
  return xla::StringToComparisonDirection(comparison_direction_string.str())
      .ValueOrDie();
}

static xla::GatherDimensionNumbers Convert_dimension_numbers(
    mlir::mhlo::GatherDimensionNumbersAttr input) {
  xla::GatherDimensionNumbers output;

  auto offset_dims = input.getOffsetDims();
  std::copy(offset_dims.begin(), offset_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_offset_dims()));

  auto collapsed_slice_dims = input.getCollapsedSliceDims();
  std::copy(collapsed_slice_dims.begin(), collapsed_slice_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_collapsed_slice_dims()));

  auto start_index_map = input.getStartIndexMap();
  std::copy(start_index_map.begin(), start_index_map.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_start_index_map()));

  output.set_index_vector_dim(input.getIndexVectorDim());
  return output;
}

static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
    mlir::mhlo::ScatterDimensionNumbersAttr input) {
  xla::ScatterDimensionNumbers output;

  auto update_window_dims = input.getUpdateWindowDims();
  std::copy(update_window_dims.begin(), update_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_update_window_dims()));

  auto inserted_window_dims = input.getInsertedWindowDims();
  std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_inserted_window_dims()));

  auto scatter_dims_to_operand_dims = input.getScatterDimsToOperandDims();
  std::copy(scatter_dims_to_operand_dims.begin(),
            scatter_dims_to_operand_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_scatter_dims_to_operand_dims()));

  output.set_index_vector_dim(input.getIndexVectorDim());
  return output;
}

// Extracts sharding from attribute string.
static std::optional<xla::OpSharding> CreateOpShardingFromStringRef(
    llvm::StringRef sharding) {
  xla::OpSharding sharding_proto;
  if (!sharding_proto.ParseFromString(sharding.str())) return std::nullopt;
  return sharding_proto;
}
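
// Example (illustrative): the string is a serialized xla::OpSharding proto,
// e.g. the bytes produced by SerializeAsString() on a proto with
// type: MAXIMAL and tile_assignment_devices: [0].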

// Returns an OpSharding proto from the "sharding" attribute of the op. If the
// op doesn't have a sharding attribute or the sharding attribute is invalid,
// returns std::nullopt.
static std::optional<xla::OpSharding> CreateOpShardingFromAttribute(
    mlir::Operation* op) {
  auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
  if (!sharding) return std::nullopt;
  return CreateOpShardingFromStringRef(sharding.getValue());
}

// Returns a FrontendAttributes proto from the "frontend_attributes" attribute
// of the op. An empty FrontendAttributes proto is returned if an op does not
// have frontend attributes.
static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
    mlir::Operation* op) {
  xla::FrontendAttributes frontend_attributes;
  auto frontend_attributes_dict =
      op->getAttrOfType<mlir::DictionaryAttr>(kFrontendAttributesAttr);

  if (!frontend_attributes_dict) return frontend_attributes;

  for (const auto& attr : frontend_attributes_dict)
    if (auto value_str_attr = attr.getValue().dyn_cast<mlir::StringAttr>())
      frontend_attributes.mutable_map()->insert(
          {attr.getName().str(), value_str_attr.getValue().str()});

  return frontend_attributes;
}

// Returns an OpMetadata proto based on the location of the op. If the location
// is unknown, an empty proto is returned. `op_name` is populated with the op
// location (converted). FileLineColLoc locations are populated by taking the
// file name and line number, and populating `source_file` and `source_line`
// respectively.
static xla::OpMetadata CreateOpMetadataFromLocation(
    mlir::Operation* op, mlir::MlirToHloConversionOptions options) {
  xla::OpMetadata metadata;
  mlir::Location loc = op->getLoc();
  if (loc.isa<mlir::UnknownLoc>()) return metadata;

  std::string name = mlir::GetNameFromLoc(loc);
  if (options.legalize_node_names) {
    mlir::LegalizeNodeName(name);
  }
  metadata.set_op_name(name);
  std::string op_type = mlir::GetOpTypeFromLoc(loc);
  mlir::LegalizeNodeName(op_type);
  metadata.set_op_type(op_type);

  if (auto name_loc = op->getLoc().dyn_cast<mlir::NameLoc>()) {
    loc = name_loc.getChildLoc();
    if (loc.isa<mlir::UnknownLoc>()) return metadata;
  }

  if (auto file_line_col_loc = loc.dyn_cast<mlir::FileLineColLoc>()) {
    metadata.set_source_file(file_line_col_loc.getFilename().str());
    metadata.set_source_line(file_line_col_loc.getLine());
  }

  return metadata;
}
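
// For instance (illustrative): a location such as
//   loc("add_1"("model.py":42:7))
// yields metadata with op_name "add_1", source_file "model.py", and
// source_line 42.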

// Checks if all shardings are set.
static bool AllOptionalShardingsAreSet(
    llvm::ArrayRef<std::optional<xla::OpSharding>> shardings) {
  return llvm::all_of(shardings,
                      [](const std::optional<xla::OpSharding>& sharding) {
                        return sharding.has_value();
                      });
}

// Extracts argument and result shardings from function.
static void ExtractShardingsFromFunction(
    mlir::func::FuncOp function,
    llvm::SmallVectorImpl<std::optional<xla::OpSharding>>* arg_shardings,
    llvm::SmallVectorImpl<std::optional<xla::OpSharding>>* ret_shardings) {
  arg_shardings->resize(function.getNumArguments(),
                        std::optional<xla::OpSharding>());
  for (int i = 0, end = function.getNumArguments(); i < end; ++i)
    if (auto sharding =
            function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());

  ret_shardings->resize(function.getNumResults(),
                        std::optional<xla::OpSharding>());
  for (int i = 0, end = function.getNumResults(); i < end; ++i)
    if (auto sharding =
            function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
}
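
// Example (illustrative): shardings are picked up from per-argument and
// per-result attributes such as
//   func.func @main(%arg0: tensor<f32> {mhlo.sharding = "..."}) -> ...
// where the attribute value is a serialized xla::OpSharding proto.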

namespace mlir {
namespace {
class ConvertToHloModule {
 public:
  using ValueLoweringMap = llvm::DenseMap<Value, xla::XlaOp>;
  using FunctionLoweringMap =
      llvm::DenseMap<mlir::func::FuncOp, xla::XlaComputation>;

  // If use_tuple_args is true, then the entry function's arguments are
  // converted to a tuple and passed as a single parameter.
  // Similarly, if return_tuple is true, then the entry function's return values
  // are converted to a tuple even when there is only a single return value.
  // Multiple return values are always converted to a tuple and returned as a
  // single value.
  explicit ConvertToHloModule(mlir::ModuleOp module,
                              xla::XlaBuilder& module_builder,
                              bool use_tuple_args, bool return_tuple,
                              MlirToHloConversionOptions options)
      : module_(module),
        module_builder_(module_builder),
        use_tuple_args_(use_tuple_args),
        return_tuple_(return_tuple),
        options_(options) {}

  // Perform the lowering to XLA. This function returns failure if an error was
  // encountered.
  //
  // TODO(hinsu): Check for dynamic shapes and exit instead of crashing.
  LogicalResult Run() {
    auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main)
      return module_.emitError(
          "conversion requires module with `main` function");

    for (auto func : module_.getOps<func::FuncOp>()) {
      if (func.empty()) continue;
      if (failed(RunOnFunction(func))) return failure();
    }
    return success();
  }

  // Lower a specific function to HLO.
  LogicalResult RunOnFunction(mlir::func::FuncOp f);

  // Lower a `mlir::Region` to an `xla::XlaComputation`.
  LogicalResult LowerRegionAsComputation(
      mlir::Region* region, xla::XlaComputation* func,
      llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
          llvm::None,
      bool ensure_single_arg = false);

  // Lower a single `Block` to an `xla::XlaComputation`.
  LogicalResult LowerBasicBlockAsFunction(
      Block* block, xla::XlaBuilder* builder, bool is_entry_function,
      bool ensure_single_arg,
      const std::vector<bool>& entry_args_same_across_replicas,
      llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
      llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
      xla::XlaComputation* result,
      llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands =
          llvm::None);

  ::xla::HloModuleProto ConsumeMainProto() {
    auto main = module_.lookupSymbol<mlir::func::FuncOp>("main");
    // This is an invariant check as Run returns failure if there is no main
    // function and so the main proto shouldn't be consumed in that case.
    CHECK(main) << "requires module to have main function";  // Crash Ok.
    return lowered_computation_[main].proto();
  }

  // Lower a function call to an HLO call instruction.
  LogicalResult LowerFunctionCall(
      mlir::func::CallOp call_op, xla::XlaBuilder* builder,
      ConvertToHloModule::ValueLoweringMap* value_lowering);

  // Look up a symbol with the specified name, returning null if no such name
  // exists.
  func::FuncOp LookUpSymbol(FlatSymbolRefAttr symbol) {
    return module_.lookupSymbol<mlir::func::FuncOp>(symbol);
  }

  // Get a reference to the lowered XLA computation for a function.
  xla::XlaComputation& GetLoweredComputation(func::FuncOp func) {
    return lowered_computation_[func];
  }

  LogicalResult Lower(
      mlir::Operation* inst, bool is_entry_function,
      llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
      xla::XlaBuilder* builder,
      ConvertToHloModule::ValueLoweringMap* value_lowering,
      xla::XlaOp* return_value);

  const MlirToHloConversionOptions& GetOptions() const { return options_; }

 private:
  LogicalResult SetEntryTupleShapesAndLeafReplication(
      Block* block, const std::vector<bool>& entry_args_same_across_replicas,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
      std::vector<bool>* leaf_replication);

  LogicalResult SetEntryTupleShardings(
      Block* block, xla::XlaBuilder* builder,
      llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes);

  // The module being lowered.
  mlir::ModuleOp module_;

  // The top-level XlaBuilder.
  xla::XlaBuilder& module_builder_;

  // Map between function and lowered computation.
  FunctionLoweringMap lowered_computation_;

  // Whether the entry function should take a single tuple as input.
  bool use_tuple_args_;

  // Whether to always return a tuple.
  bool return_tuple_;

  // Unique suffix to give to the name of the next lowered region.
  size_t region_id_ = 0;

  MlirToHloConversionOptions options_;
};

}  // namespace
}  // namespace mlir

namespace {

struct OpLoweringContext {
  llvm::DenseMap<mlir::Value, xla::XlaOp>* values;
  mlir::ConvertToHloModule* converter;
  xla::XlaBuilder* builder;
};

mlir::LogicalResult GetTuple(mlir::Operation* op,
                             mlir::Operation::operand_range values,
                             OpLoweringContext ctx,
                             llvm::SmallVectorImpl<xla::XlaOp>& results) {
  results.reserve(values.size());
  for (mlir::Value value : values) {
    if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op)))
      return mlir::failure();
  }
  return mlir::success();
}

mlir::LogicalResult GetXlaOps(mlir::Operation* op,
                              llvm::ArrayRef<mlir::Value> values,
                              OpLoweringContext ctx,
                              llvm::SmallVectorImpl<xla::XlaOp>& results) {
  results.reserve(values.size());
  for (mlir::Value value : values) {
    if (failed(GetXlaOp(value, *ctx.values, &results.emplace_back(), op)))
      return mlir::failure();
  }
  return mlir::success();
}

}  // namespace

namespace mlir {
namespace mhlo {
namespace {

LogicalResult ExportXlaOp(ComputeReshapeShapeOp, OpLoweringContext) {
  // This op has no expression in the legacy export format. It can be expanded
  // to a sequence of operations if needed in the future, but would feed into
  // ops creating unsupported dynamic shapes.
  return failure();
}

LogicalResult ExportXlaOp(CstrReshapableOp, OpLoweringContext) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(AddDependencyOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp token;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  auto operand_shape = ctx.builder->GetShape(operand).value();
  value_map[op] = xla::internal::XlaBuilderFriend::BuildAddDependency(
      ctx.builder, operand, token, operand_shape);
  return success();
}

LogicalResult ExportXlaOp(AllGatherOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  TensorType operand_type = op.operand().getType().cast<TensorType>();
  TensorType result_type = op.getType();
  if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
    return failure();
  auto all_gather_dim = op.all_gather_dim();
  int64_t shard_count = result_type.getDimSize(all_gather_dim) /
                        operand_type.getDimSize(all_gather_dim);
  value_map[op] = xla::AllGather(operand, all_gather_dim, shard_count,
                                 Convert_replica_groups(op.replica_groups()),
                                 Convert_channel_handle(op.channel_handle()));
  return success();
}
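
// For instance (illustrative): all-gathering a tensor<2x4xf32> operand into a
// tensor<2x16xf32> result along all_gather_dim = 1 gives
// shard_count = 16 / 4 = 4.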

LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }

  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::AllReduce(
      operand, computation, Convert_replica_groups(op.replica_groups()),
      Convert_channel_handle(op.channel_handle()), std::nullopt,
      Convert_use_global_device_ids(op.use_global_device_ids()));
  return success();
}

LogicalResult ExportXlaOp(ReduceScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  TensorType operand_type = op.operand().getType().cast<TensorType>();
  TensorType result_type = op.getType();
  if (!operand_type.hasStaticShape() || !result_type.hasStaticShape())
    return failure();
  auto scatter_dim = op.scatter_dimension();
  int64_t shard_count = operand_type.getDimSize(scatter_dim) /
                        result_type.getDimSize(scatter_dim);

  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }

  value_map[op] =
      xla::ReduceScatter(operand, computation, scatter_dim, shard_count,
                         Convert_replica_groups(op.replica_groups()),
                         Convert_channel_handle(op.channel_handle()));
  return success();
}

LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::BitcastConvertType(
      operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
  return success();
}

LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
  auto type = op.getType().dyn_cast<RankedTensorType>();
  if (!type) return failure();
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] =
      BroadcastInDim(operand, Convert_ArrayRef(type.getShape()),
                     Convert_broadcast_dimensions(op.broadcast_dimensions()));
  return success();
}

LogicalResult ExportXlaOp(CosineOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto result = op.getResult();
  xla::XlaOp arg;
  if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op)))
    return mlir::failure();
  auto xla_result = xla::Cos(Unwrap(arg));
  value_map[result] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(DotOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  value_map[op] = xla::Dot(
      lhs, rhs, Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type);
  return mlir::success();
}

LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  value_map[op] = xla::DotGeneral(
      lhs, rhs, Convert_dot_dimension_numbers(op.dot_dimension_numbers()),
      Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type);
  return mlir::success();
}

LogicalResult ExportXlaOp(DomainOp op, OpLoweringContext ctx) {
  auto& valueMap = *ctx.values;

  xla::Shape shape = xla::TypeToShape(op.getResult().getType());
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), valueMap, &operand, op))) return failure();

  auto entry = CreateOpShardingFromStringRef(op.entry_metadata());
  if (!entry) return failure();
  auto exit = CreateOpShardingFromStringRef(op.exit_metadata());
  if (!exit) return failure();

  valueMap[op] = xla::internal::XlaBuilderFriend::BuildDomain(
      ctx.builder, operand, *exit, *entry, shape);
  return success();
}

LogicalResult ExportXlaOp(DynamicBroadcastInDimOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(DynamicIotaOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(DynamicReshapeOp op, OpLoweringContext ctx) {
  // This op has no expression in the legacy export format.
  return failure();
}

LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
  xla::XlaComputation true_branch;
  xla::XlaComputation false_branch;
  auto& value_map = *ctx.values;

  // mhlo.IfOp does not have any operands or block arguments. The computations
  // inside the region blocks use implicit captures of values defined above.
  // In order to create the xla parameters for the functions corresponding to
  // IfOp regions, we need to infer a region-block's arguments, using all
  // the values used in the region but defined above. Note that in case there
  // are zero implicit captures for a region, we use an empty tuple as the xla
  // parameter.
  //
  // Note that the implicit values used in the true and false branch regions
  // might be different and, as a result, the xla parameters for the
  // corresponding regions could have different shapes.
  llvm::SetVector<mlir::Value> implicit_true_operand_set,
      implicit_false_operand_set;
  getUsedValuesDefinedAbove(op.true_branch(), op.true_branch(),
                            implicit_true_operand_set);
  getUsedValuesDefinedAbove(op.false_branch(), op.false_branch(),
                            implicit_false_operand_set);

  llvm::SmallVector<mlir::Value> implicit_true_operands(
      implicit_true_operand_set.begin(), implicit_true_operand_set.end());
  llvm::SmallVector<mlir::Value> implicit_false_operands(
      implicit_false_operand_set.begin(), implicit_false_operand_set.end());

  // Create xla parameters for the functions corresponding to IfOp regions
  // using the implicit capture operands. Also export the instructions within
  // those regions.
  if (failed(ctx.converter->LowerRegionAsComputation(
          &op.true_branch(), &true_branch,
          llvm::makeArrayRef(implicit_true_operands),
          /*ensure_single_arg*/ true)) ||
      failed(ctx.converter->LowerRegionAsComputation(
          &op.false_branch(), &false_branch,
          llvm::makeArrayRef(implicit_false_operands),
          /*ensure_single_arg*/ true))) {
    return failure();
  }

  // Create the Xla pred argument.
  xla::XlaOp pred;
  if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure();

  // Create the true branch Xla argument.
  llvm::SmallVector<xla::XlaOp> true_args;
  if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args)))
    return failure();
  xla::XlaOp true_arg =
      true_args.size() == 1 ? true_args[0] : Tuple(ctx.builder, true_args);

  // Create the false branch Xla argument.
  llvm::SmallVector<xla::XlaOp> false_args;
  if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args)))
    return failure();
  xla::XlaOp false_arg =
      false_args.size() == 1 ? false_args[0] : Tuple(ctx.builder, false_args);

  // Create the XLA Conditional op.
  auto ifop =
      xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch);

  // mhlo.IfOp can have multiple results; untuple all the results of the XLA
  // conditional.
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = ifop;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(ifop, item.index());
    }
  }

  return success();
}
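
// For instance (illustrative): if the true branch captures %a : tensor<f32>
// and %b : tensor<i32> from above, its lowered computation takes a single
// (f32[], s32[]) tuple parameter, and true_arg is built as
// Tuple(ctx.builder, {a, b}).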

LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
  llvm::DenseMap<mlir::Value, xla::XlaOp>& value_map = *ctx.values;
  // OperandRange operands = op.branch_operands();
  MutableArrayRef<Region> branches = op.branches();
  llvm::SmallVector<xla::XlaOp, 4> branch_operands(branches.size());
  std::vector<xla::XlaComputation> computations(branches.size());
  std::vector<xla::XlaComputation*> computations_p(branches.size());

  // mhlo.CaseOp does not have any operands or block arguments. The
  // computations inside the region blocks use implicit captures of values
  // defined above. In order to create the xla parameters for the functions
  // corresponding to CaseOp regions, we need to infer a region-block's
  // arguments, using all the values used in the region but defined above. Note
  // that in case there are zero implicit captures for a region, we use an
  // empty tuple as the xla parameter.
  //
  // Note that the implicit values used in the regions might
  // be different and, as a result, the xla parameters for the corresponding
  // regions could have different shapes.
  for (unsigned i = 0; i < branches.size(); ++i) {
    llvm::SetVector<mlir::Value> implicit_operand_set;
    getUsedValuesDefinedAbove(branches[i], branches[i], implicit_operand_set);
    llvm::SmallVector<mlir::Value> implicit_operands(
        implicit_operand_set.begin(), implicit_operand_set.end());

    // Create the branches[i]'s Xla argument.
    llvm::SmallVector<xla::XlaOp> args;
    if (failed(GetXlaOps(op, implicit_operands, ctx, args))) return failure();
    branch_operands[i] = args.size() == 1 ? args[0] : Tuple(ctx.builder, args);

    // Create xla parameters for the function corresponding to region
    // branches[i] using the implicit capture operands. Also export the
    // instructions within that region.
    computations_p[i] = &computations[i];
    if (failed(ctx.converter->LowerRegionAsComputation(
            &branches[i], computations_p[i],
            llvm::makeArrayRef(implicit_operands),
            /*ensure_single_arg*/ true)))
      return failure();
  }

  xla::XlaOp index;
  if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure();

  xla::XlaOp caseop = xla::Conditional(index, computations_p, branch_operands);

  // mhlo.CaseOp can have multiple results; untuple all the results of the XLA
  // conditional.
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = caseop;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(caseop, item.index());
    }
  }
  return success();
}

// Specialize CompareOp export to set the broadcast_dimensions argument.
mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
                                OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  auto dir = Convert_comparison_direction(
      mlir::mhlo::stringifyComparisonDirection(op.comparison_direction()));
  auto type_attr = op.compare_typeAttr();

  xla::XlaOp xla_result;
  if (type_attr && type_attr.getValue() != mlir::mhlo::ComparisonType::NOTYPE) {
    auto type = xla::StringToComparisonType(
                    stringifyComparisonType(type_attr.getValue()).str())
                    .ValueOrDie();
    xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
  } else {
    xla_result = xla::Compare(lhs, rhs, dir);
  }
  value_map[op] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(ConstantOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(mlir::mhlo::ConvolutionOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp lhs, rhs;
  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
  xla::PrimitiveType preferred_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()));
  xla::XlaOp xla_result = xla::ConvGeneralDilated(
      lhs, rhs, Convert_window_strides(op.window_strides()),
      Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
      Convert_rhs_dilation(op.rhs_dilation()),
      xla::ConvertConvDimensionNumbers(op.dimension_numbers()),
      Convertuint64_t(op.feature_group_count()),
      Convertuint64_t(op.batch_group_count()),
      Unwrap(Convert_precision_config(op.precision_config())),
      preferred_element_type, Convert_window_reversal(op.window_reversal()));
  value_map[op] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] = xla::ConvertElementType(
      operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
  return success();
}

LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
  if (op.getNumResults() != 1)
    return op.emitOpError() << "with multiple results cannot be exported";

  if (op.called_computations().size() > 1)
    return op.emitOpError()
           << "cannot export with more than one called computation";

  // Custom call can be exported either with a called computation or with
  // layout attributes. The XlaBuilder API does not allow both.
  if (!op.called_computations().empty() && op.operand_layouts() &&
      op.result_layouts()) {
    return op.emitOpError() << "cannot export if both called computation and "
                               "layouts are specified";
  }

  Value result = op.getResult(0);
  llvm::SmallVector<xla::XlaOp> args;
  if (failed(GetTuple(op, op.operands(), ctx, args))) return failure();
  auto xla_api_version = xla::ConvertCustomCallApiVersion(op.api_version());
  if (!xla_api_version.ok()) return failure();
  auto& value_map = *ctx.values;

  if (op.called_computations().size() == 1) {
    mlir::func::FuncOp callee = ctx.converter->LookUpSymbol(
        op.called_computations()[0].cast<FlatSymbolRefAttr>());
    if (failed(ctx.converter->RunOnFunction(callee))) return failure();
    xla::XlaComputation& computation =
        ctx.converter->GetLoweredComputation(callee);
    value_map[result] = xla::CustomCallWithComputation(
        ctx.builder, std::string(op.call_target_name()), args, computation,
        xla::TypeToShape(result.getType()), std::string(op.backend_config()),
        op.has_side_effect(),
        /*output_operand_aliasing=*/{},
        /*literal=*/nullptr,
        /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/*xla_api_version);
    return success();
  }

  if (op.operand_layouts() && op.result_layouts()) {
    auto operand_shapes_with_layout = ConvertTypesToShapesWithLayout(
        op.getOperandTypes(), op.operand_layouts().getValue());
    xla::Shape result_shape_with_layout = GetCustomCallResultShapeWithLayout(
        result.getType(), op.result_layouts().getValue());
    value_map[result] = xla::CustomCallWithLayout(
        ctx.builder, std::string(op.call_target_name()), args,
        result_shape_with_layout, operand_shapes_with_layout,
        std::string(op.backend_config()), op.has_side_effect(),
        /*output_operand_aliasing=*/{},
        /*literal=*/nullptr,
        /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/*xla_api_version);
    return success();
  }

  value_map[result] = xla::CustomCall(
      ctx.builder, std::string(op.call_target_name()), args,
      xla::TypeToShape(result.getType()), std::string(op.backend_config()),
      op.has_side_effect(), /*output_operand_aliasing=*/{},
      /*literal=*/nullptr,
      /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
      /*api_version=*/*xla_api_version);
  return success();
}

LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp token;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  // mhlo.infeed produces multiple results. The shape argument expected by the
  // xla client API is a tuple type with two element-types:
  // data_type : A tuple containing all the mhlo.infeedOp result types except
  //             the token type.
  // token_type : The last result type of mhlo.infeedOp.
  auto result_types = op.getResultTypes();
  auto num_results = op.getNumResults();

  xla::Shape token_shape = xla::TypeToShape(result_types[num_results - 1]);
  std::vector<xla::Shape> subshapes;
  for (const auto& item : llvm::enumerate(result_types)) {
    if (item.index() == num_results - 1) break;
    subshapes.push_back(xla::TypeToShape(item.value()));
  }

  xla::Shape data_shape = xla::ShapeUtil::MakeTupleShape(subshapes);
  auto xla_result =
      xla::InfeedWithToken(token, data_shape, std::string(op.infeed_config()));
  ctx.builder->ClearSharding();

  if (!subshapes.empty()) {
    auto data_tuple_element = xla::GetTupleElement(xla_result, 0);
    for (const auto& item : llvm::enumerate(op.getResults())) {
      if (item.index() == num_results - 1) break;
      value_map[item.value()] =
          xla::GetTupleElement(data_tuple_element, item.index());
    }
  }

  value_map[op.getResult(num_results - 1)] =
      xla::GetTupleElement(xla_result, 1);

  return success();
}
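
// Example (illustrative): an mhlo.infeed with result types
// (tensor<3xi32>, tensor<4xf32>, !mhlo.token) is lowered with
// data_shape = (s32[3], f32[4]); the data elements and the trailing token are
// then unpacked from the result via GetTupleElement.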

LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  value_map[op] = xla::Iota(ctx.builder, xla::TypeToShape(op.getType()),
                            op.iota_dimension());
  return success();
}

LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.computation(),
                                                     &computation))) {
    return failure();
  }
  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
  value_map[op] = xla::Map(ctx.builder, operands, computation,
                           Convert_dimensions(op.dimensions()));
  return success();
}

LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;

  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();

  xla::XlaOp operand = Tuple(ctx.builder, operands);

  std::vector<xla::Shape> subshapes;
  for (auto operand : op.operands())
    subshapes.push_back(xla::TypeToShape(operand.getType()));

  xla::Shape shape_with_layout = xla::ShapeUtil::MakeTupleShape(subshapes);

  xla::XlaOp token;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  value_map[op] = xla::OutfeedWithToken(operand, token, shape_with_layout,
                                        std::string(op.outfeed_config()));
  return success();
}

LogicalResult ExportXlaOp(PartitionIdOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::Shape shape = xla::TypeToShape(op.getResult().getType());
  value_map[op] =
      xla::internal::XlaBuilderFriend::BuildPartitionId(ctx.builder, shape);
  return success();
}

LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::PaddingConfig padding_config;
  auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
  auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
  auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
  for (int64_t i = 0, end = edge_padding_low.size(); i < end; ++i) {
    auto* dims = padding_config.add_dimensions();
    dims->set_edge_padding_low(edge_padding_low[i]);
    dims->set_edge_padding_high(edge_padding_high[i]);
    dims->set_interior_padding(interior_padding[i]);
  }
  xla::XlaOp operand, padding_value;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op)))
    return failure();

  value_map[op] = xla::Pad(operand, padding_value, padding_config);
  return success();
}

LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;

  xla::XlaOp token;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  // mhlo.recvOp produces multiple results. The shape argument expected by the
  // xla client API is a tuple type with two element-types:
1277   // data_type : A tuple containing all the mhlo.RecvOp result types except
1278   //             the token type.
1279   // token_type : The last result type of mhlo.recvOp.
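  // For illustration (hypothetical shapes): a recv of a single
  // tensor<2xf32> plus a token has data_shape f32[2] (a lone subshape is
  // not wrapped in a tuple, unlike mhlo.infeed above), and
  // is_host_transfer() selects xla::RecvFromHost over xla::RecvWithToken.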
  auto result_types = op.getResultTypes();
  auto num_results = op.getNumResults();

  xla::Shape token_shape = xla::TypeToShape(result_types[num_results - 1]);
  std::vector<xla::Shape> subshapes;
  for (const auto& item : llvm::enumerate(result_types)) {
    if (item.index() == num_results - 1) break;
    subshapes.push_back(xla::TypeToShape(item.value()));
  }

  xla::Shape data_shape;
  if (subshapes.size() == 1)
    data_shape = subshapes[0];
  else
    data_shape = xla::ShapeUtil::MakeTupleShape(subshapes);

  xla::XlaOp xla_result;
  if (op.is_host_transfer()) {
    xla_result = xla::RecvFromHost(token, data_shape,
                                   Convert_channel_handle(op.channel_handle()));
  } else {
    xla_result = xla::RecvWithToken(
        token, data_shape, Convert_channel_handle(op.channel_handle()));
  }

  auto data_tuple_element = xla::GetTupleElement(xla_result, 0);
  if (subshapes.size() == 1) {
    value_map[op.getResult(0)] = data_tuple_element;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      if (item.index() == num_results - 1) break;
      value_map[item.value()] =
          xla::GetTupleElement(data_tuple_element, item.index());
    }
  }

  value_map[op.getResult(num_results - 1)] =
      xla::GetTupleElement(xla_result, 1);

  return success();
}
1321 
LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation body;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
    return failure();
  }
  llvm::SmallVector<xla::XlaOp> operands, init_values;
  if (failed(GetTuple(op, op.operands(), ctx, operands)) ||
      failed(GetTuple(op, op.init_values(), ctx, init_values))) {
    return failure();
  }
  xla::XlaOp result =
      xla::Reduce(ctx.builder, operands, init_values, body,
                  Convert_broadcast_dimensions(op.dimensions()));
  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = result;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(result, item.index());
    }
  }
  return success();
}

LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation body;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) {
    return failure();
  }
  llvm::SmallVector<xla::XlaOp> operands, init_values;
  if (failed(GetTuple(op, op.operands(), ctx, operands)) ||
      failed(GetTuple(op, op.init_values(), ctx, init_values))) {
    return failure();
  }

  xla::XlaOp result = xla::ReduceWindowWithGeneralPadding(
      operands, init_values, body, ConvertDenseIntAttr(op.window_dimensions()),
      ConvertDenseIntAttr(op.window_strides()),
      ConvertDenseIntAttr(op.base_dilations()),
      ConvertDenseIntAttr(op.window_dilations()),
      Convert_padding(op.padding()));

  if (op.getNumResults() == 1) {
    value_map[op.getResult(0)] = result;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      value_map[item.value()] = xla::GetTupleElement(result, item.index());
    }
  }
  return success();
}

LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();

  value_map[op] =
      xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions());
  return success();
}

LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
  // Fails on purpose because `mhlo::ReturnOp` is handled by special-purpose
  // logic in `ConvertToHloModule::Lower`.
  return failure();
}

LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto results = op.getResults();
  auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()];
  auto xla_result = xla::RngBitGenerator(
      static_cast<xla::RandomAlgorithm>(op.rng_algorithm()), Unwrap(xla_arg_1),
      xla::TypeToShape(results[1].getType()));

  for (const auto& item : llvm::enumerate(results))
    value_map[item.value()] = xla::GetTupleElement(xla_result, item.index());

  return mlir::success();
}
1404 
LogicalResult ExportXlaOp(XlaRngGetAndUpdateStateOp op, OpLoweringContext ctx) {
  // This op does not exist in the XLA builder interface.
  (*ctx.values)[op.getResult()] =
      xla::internal::XlaBuilderFriend::BuildRngGetAndUpdateState(
          ctx.builder, static_cast<int64_t>(op.delta()),
          xla::TypeToShape(op.getType()));
  return mlir::success();
}

LogicalResult ExportXlaOp(BatchNormGradOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto results = op.getResults();

  xla::XlaOp operand, scale, mean, variance, grad_output;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.scale(), value_map, &scale, op))) return failure();
  if (failed(GetXlaOp(op.mean(), value_map, &mean, op))) return failure();
  if (failed(GetXlaOp(op.variance(), value_map, &variance, op)))
    return failure();
  if (failed(GetXlaOp(op.grad_output(), value_map, &grad_output, op)))
    return failure();

  auto xla_result =
      xla::BatchNormGrad(operand, scale, mean, variance, grad_output,
                         ConvertAPFloat(op.epsilon()), op.feature_index());

  for (const auto& item : llvm::enumerate(results))
    value_map[item.value()] = xla::GetTupleElement(xla_result, item.index());

  return mlir::success();
}

LogicalResult ExportXlaOp(BatchNormTrainingOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto results = op.getResults();

  xla::XlaOp operand, scale, offset;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.scale(), value_map, &scale, op))) return failure();
  if (failed(GetXlaOp(op.offset(), value_map, &offset, op))) return failure();

  auto xla_result = xla::BatchNormTraining(
      operand, scale, offset, ConvertAPFloat(op.epsilon()), op.feature_index());

  for (const auto& item : llvm::enumerate(results))
    value_map[item.value()] = xla::GetTupleElement(xla_result, item.index());

  return mlir::success();
}

LogicalResult ExportXlaOp(RngOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp a, b;
  if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure();
  if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure();

  if (op.rng_distribution() == RngDistribution::UNIFORM) {
    value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType()));
    return success();
  } else if (op.rng_distribution() == RngDistribution::NORMAL) {
    value_map[op] = xla::RngNormal(a, b, xla::TypeToShape(op.getType()));
    return success();
  }
  return failure();
}

LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation update_computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
                                                     &update_computation))) {
    return failure();
  }
  xla::ScatterDimensionNumbers dimension_numbers =
      Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());

  llvm::SmallVector<xla::XlaOp> operands;
  llvm::SmallVector<xla::XlaOp> updates;
  if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
  if (failed(GetTuple(op, op.updates(), ctx, updates))) return failure();

  xla::XlaOp scatter_indices;
  if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op)))
    return failure();

  auto scatter_op = xla::Scatter(operands, scatter_indices, updates,
                                 update_computation, dimension_numbers,
                                 op.indices_are_sorted(), op.unique_indices());
  if (op->getNumResults() == 1) {
    value_map[op.getResult(0)] = scatter_op;
    return success();
  }

  // mhlo.ScatterOp supports multiple results; untuple all the results of the
  // XLA scatter op.
  for (const auto& it : llvm::enumerate(op.getResults())) {
    value_map[it.value()] = xla::GetTupleElement(scatter_op, it.index());
  }

  return success();
}

LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation select;
  xla::XlaComputation scatter;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.select(), &select)) ||
      failed(
          ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) {
    return failure();
  }
  xla::XlaOp operand, source, init_value;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure();
  if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op)))
    return failure();

  value_map[op] = xla::SelectAndScatterWithGeneralPadding(
      operand, select, ConvertDenseIntAttr(op.window_dimensions()),
      ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()),
      source, init_value, scatter);
  return success();
}

LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;

  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();

  xla::XlaOp operand;
  if (operands.size() == 1)
    operand = operands[0];
  else
    operand = Tuple(ctx.builder, operands);

  xla::XlaOp token;
  if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure();

  if (op.is_host_transfer()) {
    value_map[op] = xla::SendToHost(
        operand, token, operand.builder()->GetShape(operand).value(),
        Convert_channel_handle(op.channel_handle()));
    return success();
  }
  value_map[op] = xla::SendWithToken(
      operand, token, Convert_channel_handle(op.channel_handle()));
  return success();
}
1553 
mlir::LogicalResult ExportXlaOp(mlir::mhlo::SineOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto result = op.getResult();
  xla::XlaOp arg;
  if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &arg, op)))
    return mlir::failure();
  auto xla_result = xla::Sin(Unwrap(arg));
  value_map[result] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(SliceOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) {
  xla::XlaComputation comparator;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.comparator(),
                                                     &comparator)))
    return failure();

  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.operands(), ctx, operands))) return failure();
  auto sorted = xla::Sort(operands, comparator, op.dimension(), op.is_stable());

  auto& value_map = *ctx.values;
  auto shape_or = sorted.builder()->GetShape(sorted);
  if (!shape_or.ok()) {
    return op.emitError(shape_or.status().ToString());
  }

  xla::Shape& shape = shape_or.ValueOrDie();
  if (!shape.IsTuple()) {
    value_map[op.getResult(0)] = sorted;
    return success();
  }

  // mhlo.SortOp supports multiple results; untuple all the results of the
  // XLA sort op.
  for (const auto& it : llvm::enumerate(op.getResults())) {
    value_map[it.value()] = xla::GetTupleElement(sorted, it.index());
  }
  return success();
}

LogicalResult ExportXlaOp(SubtractOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto result = op.getResult();
  xla::XlaOp lhs;
  if (failed(GetXlaOp(*op.getODSOperands(0).begin(), value_map, &lhs, op)))
    return mlir::failure();
  xla::XlaOp rhs;
  if (failed(GetXlaOp(*op.getODSOperands(1).begin(), value_map, &rhs, op)))
    return mlir::failure();
  auto xla_result = xla::Sub(Unwrap(lhs), Unwrap(rhs));
  value_map[result] = xla_result;
  return mlir::success();
}

LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) {
  // TODO(atondwal): remove mhlo.trace
  return success();
}

LogicalResult ExportXlaOp(UnaryEinsumOp op, OpLoweringContext ctx) {
  // Intentionally fails: UnaryEinsumOp is always lowered to an EinsumOp with
  // two operands before export.
  return failure();
}

LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) {
  xla::XlaComputation condition;
  xla::XlaComputation body;
  if (failed(ctx.converter->LowerRegionAsComputation(
          &op.body(), &body, llvm::None, /*ensure_single_arg*/ true)) ||
      failed(ctx.converter->LowerRegionAsComputation(
          &op.cond(), &condition, llvm::None, /*ensure_single_arg*/ true))) {
    return failure();
  }

  // In case MHLO's WhileOp has multiple operands, create an xla::Tuple from
  // those operands to be used as the sole operand of xla::While.
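  // For example (hypothetical operands): a while over %a : tensor<f32> and
  // %b : tensor<i32> is exported as
  //   t = tuple(a, b); w = while(condition, body, t)
  // and each mhlo result i is then recovered via get-tuple-element(w, i).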
  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure();

  xla::XlaOp operand = operands[0];
  if (operands.size() > 1) operand = Tuple(ctx.builder, operands);

  auto whileop = xla::While(condition, body, operand);

  auto& value_map = *ctx.values;
  auto shape_or = whileop.builder()->GetShape(whileop);
  if (!shape_or.ok()) {
    return op.emitError(shape_or.status().ToString());
  }

  xla::Shape& shape = shape_or.ValueOrDie();
  if (!shape.IsTuple()) {
    value_map[op.getResult(0)] = whileop;
    return success();
  }

  // mhlo.WhileOp supports multiple results; untuple all the results of the
  // XLA while op.
  for (const auto& it : llvm::enumerate(op.getResults())) {
    value_map[it.value()] = xla::GetTupleElement(whileop, it.index());
  }

  return success();
}

LogicalResult ExportXlaOp(OptimizationBarrierOp op, OpLoweringContext ctx) {
  // In case MHLO's OptimizationBarrierOp has multiple operands, create an
  // xla::Tuple from those operands to be used as the sole operand of
  // xla::OptimizationBarrier.
  llvm::SmallVector<xla::XlaOp> operands;
  if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure();
  if (operands.empty()) return success();

  auto& value_map = *ctx.values;
  if (operands.size() == 1) {
    value_map[op.getResult(0)] = xla::OptimizationBarrier(operands[0]);
  } else {
    auto result = xla::OptimizationBarrier(Tuple(ctx.builder, operands));

    for (const auto& it : llvm::enumerate(op.getResults())) {
      value_map[it.value()] = xla::GetTupleElement(result, it.index());
    }
  }

  return success();
}

LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
  if (!op.fusion_kind()) {
    op.emitOpError() << "requires fusion kind for HLO translation";
    return failure();
  }

  xla::XlaComputation fused_computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(),
                                                     &fused_computation)))
    return failure();

  auto& values = *ctx.values;
  llvm::SmallVector<xla::XlaOp, 4> operands;
  for (auto operand : op.operands()) operands.push_back(values[operand]);

  auto fusion_kind_string =
      mlir::mhlo::stringifyFusionKind(op.fusion_kind().getValue());
  xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
      ctx.builder, operands,
      absl::string_view(fusion_kind_string.data(), fusion_kind_string.size()),
      fused_computation);
  if (op.getNumResults() == 1) {
    values[op.getResult(0)] = fusion;
  } else {
    for (const auto& item : llvm::enumerate(op.getResults())) {
      values[item.value()] = xla::GetTupleElement(fusion, item.index());
    }
  }
  return success();
}

LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaOp operand;
  if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
  xla::XlaOp bitcast = xla::internal::XlaBuilderFriend::BuildBitcast(
      ctx.builder, operand, xla::TypeToShape(op.getType()));
  value_map[op] = bitcast;
  if (ctx.converter->GetOptions().propagate_bitcast_layouts_to_backend_config) {
    // Encode the source and result layouts of the bitcast into the XLA HLO
    // backend config as a protobuf. Note that this is a temporary solution
    // which will go away once XLA:GPU stops falling back to XLA HLO Elemental
    // IR emitters.
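    // Sketch of the resulting config (illustrative rank-2 layouts only):
    //   source_layout { minor_to_major: 0 minor_to_major: 1 }
    //   result_layout { minor_to_major: 1 minor_to_major: 0 }
    // serialized into the instruction's backend_config string below.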
    xla::HloInstructionProto* bitcast_proto =
        xla::internal::XlaBuilderFriend::GetInstruction(bitcast);
    xla::HloInstructionProto* operand_proto =
        xla::internal::XlaBuilderFriend::GetInstruction(operand);
    xla::LayoutProto result_layout =
        ExtractLayout(op, bitcast_proto->shape().dimensions_size(),
                      "result_layout")
            .ToProto();
    xla::LayoutProto source_layout =
        ExtractLayout(op, operand_proto->shape().dimensions_size(),
                      "source_layout")
            .ToProto();
    xla::gpu::BitcastBackendConfig bitcast_config;
    *bitcast_config.mutable_source_layout() = source_layout;
    *bitcast_config.mutable_result_layout() = result_layout;
    *bitcast_proto->mutable_backend_config() =
        bitcast_config.SerializeAsString();
  }
  return success();
}

LogicalResult ExportXlaOp(RealDynamicSliceOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicPadOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicGatherOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(DynamicConvOp op, OpLoweringContext ctx) {
  return failure();
}

LogicalResult ExportXlaOp(UniformQuantizeOp op, OpLoweringContext ctx) {
  // Currently, it doesn't have an XLA builder equivalent.
  // TODO(b/230671877): Implement XLA import/export for quantized MHLO ops.
  return failure();
}

LogicalResult ExportXlaOp(UniformDequantizeOp op, OpLoweringContext ctx) {
  // Currently, it doesn't have an XLA builder equivalent.
  // TODO(b/230671877): Implement XLA import/export for quantized MHLO ops.
  return failure();
}

}  // namespace
}  // namespace mhlo
}  // namespace mlir

#include "tensorflow/compiler/mlir/xla/operator_writers.inc"

namespace mlir {
namespace {

StatusOr<xla::Literal> CreateArrayLiteralFromAttr(ElementsAttr attr,
                                                  xla::Layout layout) {
  auto dense_attr = attr.dyn_cast<DenseElementsAttr>();
  if (!dense_attr)
    return tensorflow::errors::Unimplemented(
        "Only dense elements attr are supported");

  xla::Shape shape = xla::TypeToShape(dense_attr.getType());

#define ELEMENTS_ATTR_TO_LITERAL(xla_type, cpp_type)                         \
  case xla_type: {                                                           \
    xla::Array<cpp_type> source_data(shape.dimensions());                    \
    source_data.SetValues(dense_attr.getValues<cpp_type>());                 \
    return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, layout); \
  }
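// For reference, the F32 case of the macro above expands (modulo
// formatting) to:
//   case xla::PrimitiveType::F32: {
//     xla::Array<float> source_data(shape.dimensions());
//     source_data.SetValues(dense_attr.getValues<float>());
//     return xla::LiteralUtil::CreateFromArrayWithLayout(source_data,
//                                                        layout);
//   }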
1801 
  switch (shape.element_type()) {
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::PRED, bool)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F32, float)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F64, double)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S8, int8)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S16, int16)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S32, int32)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::S64, int64_t)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U8, uint8)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U16, uint16)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U32, uint32)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::U64, uint64)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C64, std::complex<float>)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::C128, std::complex<double>)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::F16, Eigen::half)
    ELEMENTS_ATTR_TO_LITERAL(xla::PrimitiveType::BF16, Eigen::bfloat16)
    default:
      return tensorflow::errors::Internal(absl::StrCat(
          "Unsupported type: ", xla::PrimitiveType_Name(shape.element_type())));
  }
#undef ELEMENTS_ATTR_TO_LITERAL
}

LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout,
                            xla::ShapeProto* shape) {
  // In the case of tuples, ShapeProtos can be nested, and so can the mlir
  // attribute describing the layout. So recurse into the subshapes in both data
  // structures in parallel.
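  // For example (hypothetical attribute): for a tuple shape
  // (f32[2,3], token[]) a layout attribute [[1, 0], unit] assigns the
  // layout {1,0} to the array subshape; the unit entry for the token is
  // skipped below.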
  if (shape->element_type() == xla::TUPLE) {
    auto subshapes = shape->mutable_tuple_shapes();

    // 'layout' does not take the token attribute into account, so skip the
    // corresponding entry from the xla shape proto.
    size_t subshapes_data_size = subshapes->size();
    if (!subshapes->empty() &&
        subshapes->Mutable(subshapes->size() - 1)->element_type() == xla::TOKEN)
      subshapes_data_size = subshapes->size() - 1;

    if (layout.size() != subshapes_data_size) {
      op->emitOpError() << "Expected layout of size " << subshapes_data_size
                        << ", but found " << layout.size();
      return failure();
    }
    for (int i = 0; i < subshapes_data_size; i++) {
      mlir::Attribute child = layout[i];
      if (child.isa<mlir::UnitAttr>()) {
        // Ignore unit attributes; they are used only for tokens.
        continue;
      }
      mlir::ArrayAttr c = child.dyn_cast<mlir::ArrayAttr>();
      if (!c) {
        op->emitOpError() << "Type Error: Expected layout array attribute";
        return failure();
      }
      if (failed(ConvertLayout(op, c, subshapes->Mutable(i)))) {
        return failure();
      }
    }
  } else {
    int rank = shape->dimensions().size();
    if (rank) {
      if (layout.size() != rank) {
        return failure();  // pass error down
      }
      std::vector<int64_t> array(rank);
      for (int i = 0; i < rank; i++) {
        mlir::IntegerAttr attr = layout[i].dyn_cast<mlir::IntegerAttr>();
        if (!attr) {
          op->emitOpError() << "Type Error: Expected layout integer attribute";
          return failure();
        }
        array[i] = attr.getInt();
      }
      *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
    }
  }
  return success();
}

// Assigns layouts from 'layout' to shape.
// The function accepts any of the following shapes
//   one or more array-shape(s) of infeed data
//   Tuple(Tuple(zero or more array-shape w.r.t data), token_type)
//
// 'layout' of the mhlo.InfeedOp 'op' is
//    [zero or more layout for each array-shape w.r.t data]
// 'layout_index' indexes into 'layout' accessing a layout corresponding to a
// shape.
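// For example (hypothetical shapes): for an infeed shape
// ((f32[2,3], s32[]), token[]), a layout attribute [[1, 0], []] assigns
// the layout {1,0} to f32[2,3] and the empty layout to the scalar s32[].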
LogicalResult ConvertInfeedtLayout(mlir::Operation* op,
                                   const mlir::ArrayAttr& layout,
                                   xla::ShapeProto* shape,
                                   int64_t layout_index = 0) {
  if (shape->element_type() != xla::TUPLE) {
    // Handles the following shape:
    //   single array-shape of infeed data
    mlir::ArrayAttr child_layout =
        layout[layout_index].dyn_cast<mlir::ArrayAttr>();
    if (!child_layout) {
      op->emitOpError() << "Type Error: Expected layout array attribute";
      return failure();
    }

    int rank = shape->dimensions().size();
    if (rank) {
      if (child_layout.size() != rank) {
        return failure();  // pass error down
      }
      std::vector<int64_t> array(rank);
      for (int i = 0; i < rank; i++) {
        mlir::IntegerAttr attr = child_layout[i].dyn_cast<mlir::IntegerAttr>();
        if (!attr) {
          op->emitOpError() << "Type Error: Expected layout integer attribute";
          return failure();
        }
        array[i] = attr.getInt();
      }
      *shape->mutable_layout() = xla::LayoutUtil::MakeLayout(array).ToProto();
    }

    return success();
  }

  auto subshapes = shape->mutable_tuple_shapes();
  auto datashape = subshapes->Mutable(0);

  if (datashape->element_type() == xla::TUPLE) {
    //   Handles the following shape:
    //     (Tuple(zero or more array-shape w.r.t data), token_type)
    auto data_subshapes = datashape->mutable_tuple_shapes();
    if (layout.size() != data_subshapes->size()) {
      op->emitOpError() << "Expected " << data_subshapes->size()
                        << " layout attribute(s) for infeed data, but found "
                        << layout.size();
      return failure();
    }

    for (int i = 0; i < data_subshapes->size(); i++) {
      if (failed(
              ConvertInfeedtLayout(op, layout, data_subshapes->Mutable(i), i)))
        return failure();
    }
  } else {
    //   Handles the following shapes:
    //     array-shapes of two or more infeed data
    if (layout.size() != subshapes->size()) {
      op->emitOpError() << "Expected " << subshapes->size()
                        << " layout attribute(s) for infeed data, but found "
                        << layout.size();
      return failure();
    }

    for (int i = 0; i < subshapes->size(); i++) {
      if (failed(ConvertInfeedtLayout(op, layout, subshapes->Mutable(i), i)))
        return failure();
    }
  }

  return success();
}

// MHLO and XLA HLO disagree on the meaning of addition of `pred` / `i1`, so
// there has to be a special case somewhere to account for the difference.  To
// get the expected behavior of an `AddOp` on `i1`, we have to use `xor`.  Since
// the majority of the conversion is generated code, we just sidestep it here
// for this single case, and inline the code to emit an `xor`.
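// For example: in MHLO, adding two i1 values is addition modulo 2, so
// true + true == false. xor(true, true) == false matches that behavior,
// which is why the special case below emits xla::Xor instead of xla::Add.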
LogicalResult ExportXlaOperatorWrapped(mlir::Operation* inst,
                                       OpLoweringContext ctx) {
  auto op = dyn_cast<mlir::mhlo::AddOp>(inst);
  if (op && op.getResult()
                .getType()
                .cast<mlir::TensorType>()
                .getElementType()
                .isSignlessInteger(1)) {
    auto& value_map = *ctx.values;
    auto result = op.getResult();
    xla::XlaOp xla_arg_0;
    if (failed(GetXlaOp(op.lhs(), value_map, &xla_arg_0, op)))
      return mlir::failure();
    xla::XlaOp xla_arg_1;
    if (failed(GetXlaOp(op.rhs(), value_map, &xla_arg_1, op)))
      return mlir::failure();
    auto xla_result = xla::Xor(Unwrap(xla_arg_0), Unwrap(xla_arg_1));
    value_map[result] = xla_result;
    return mlir::success();
  }

  return ExportXlaOperator(inst, ctx);
}

LogicalResult ConvertToHloModule::Lower(
    mlir::Operation* inst, bool is_entry_function,
    llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
    xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering,
    xla::XlaOp* return_value) {
  // Explicitly fail for ops that are not supported for export.
  if (inst->getDialect() !=
          inst->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>() &&
      !mlir::isa<mlir::func::ConstantOp, mlir::arith::ConstantOp,
                 mlir::func::CallOp, mlir::tensor::CastOp,
                 mlir::func::ReturnOp>(inst)) {
    inst->emitOpError("unsupported op for export to XLA");
    return failure();
  }

  *return_value = xla::XlaOp();

  // See MlirToHloConversionOptions for more about layouts.
  auto propagate_layouts = [this](mlir::Operation* inst,
                                  xla::XlaOp xla_op) -> mlir::LogicalResult {
    if (options_.propagate_layouts) {
      auto* shape = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
                        ->mutable_shape();
      // TODO(kramm): merge this with ConvertLayout.
      *shape = ExtractXlaShape(inst).ToProto();
    }

    return success();
  };

  if (succeeded(
          ExportXlaOperatorWrapped(inst, {value_lowering, this, builder}))) {
    if (inst->getNumResults() == 1) {
      auto iter = value_lowering->find(inst->getResult(0));
      if (iter == value_lowering->end()) {
        inst->emitOpError(
            "inst has a result, but it's not found in value_lowering");
        return failure();
      }
      if (failed(propagate_layouts(inst, iter->second))) {
        return failure();
      }
    }
    // For infeed ops stemming back to InfeedDequeueTuple, respect the
    // layout attribute, and create the corresponding layout in hlo.
    if (isa<mhlo::InfeedOp>(inst)) {
      mlir::ArrayAttr layout =
          inst->getAttrOfType<mlir::ArrayAttr>(kLayoutAttr);

      if (layout) {
        // We propagate layout to the following three ops:
        // L1: For each data-result of mhlo.InfeedOp, we find the exported
        // xla::kGetTupleElement and propagate the layout.
        //
        // L2: For the token-result of mhlo.InfeedOp (result at last index),
        // we extract the xla::kInfeed op using the corresponding
        // xla::kGetTupleElement and propagate the layout to it.
        //
        // L3: In case there are non-zero data-results, there exists an
        // additional xla::kGetTupleElement accessing a tuple of the
        // data-results. We need to propagate the layout to that
        // xla::kGetTupleElement as well.
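        // Sketch of the exported instruction graph for two data results
        // (names illustrative):
        //   %infeed = infeed(%token)                      <- L2 layout
        //   %data   = get-tuple-element(%infeed), index=0 <- L3 layout
        //   %d0     = get-tuple-element(%data), index=0   <- L1 layout
        //   %d1     = get-tuple-element(%data), index=1   <- L1 layout
        //   %tok    = get-tuple-element(%infeed), index=1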
        auto num_results = inst->getNumResults();
        bool propagate_layout_to_data_tuple = true;
        for (unsigned i = 0; i < num_results; i++) {
          auto iter = value_lowering->find(inst->getResult(i));
          if (iter == value_lowering->end()) {
            inst->emitOpError() << "inst's result value at index " << i
                                << " has no match in value_lowering";
            return failure();
          }
          auto xla_gte_op = iter->second;
          xla::HloInstructionProto* get_tuple_element_proto =
              xla::internal::XlaBuilderFriend::GetInstruction(xla_gte_op);

          assert(xla::StringToHloOpcode(get_tuple_element_proto->opcode())
                         .ValueOrDie() == xla::HloOpcode::kGetTupleElement &&
                 "Every result of mhlo.InfeedOp should be mapped to an "
                 "xla::HloOpcode::kGetTupleElement");

          if (i == num_results - 1) {
            // L2
            xla::HloInstructionProto* xla_infeed_op_proto =
                xla::internal::XlaBuilderFriend::GetInstructionByHandle(
                    xla_gte_op.builder(),
                    get_tuple_element_proto->operand_ids(0));

            assert(xla::StringToHloOpcode(xla_infeed_op_proto->opcode())
                           .ValueOrDie() == xla::HloOpcode::kInfeed &&
                   "Expected xla::HloOpcode::kInfeed op");

            auto* shape = xla_infeed_op_proto->mutable_shape();
            if (failed(ConvertInfeedtLayout(inst, layout, shape)))
              return failure();

          } else {
            // L1
            auto* shape = get_tuple_element_proto->mutable_shape();
            if (failed(ConvertInfeedtLayout(inst, layout, shape, i)))
              return failure();

            // L3
            if (propagate_layout_to_data_tuple) {
              xla::HloInstructionProto* data_tuple_proto =
                  xla::internal::XlaBuilderFriend::GetInstructionByHandle(
                      xla_gte_op.builder(),
                      get_tuple_element_proto->operand_ids(0));
              auto* data_tuple_shape = data_tuple_proto->mutable_shape();

              assert(xla::StringToHloOpcode(data_tuple_proto->opcode())
                             .ValueOrDie() ==
                         xla::HloOpcode::kGetTupleElement &&
                     "Expected an xla::HloOpcode::kGetTupleElement for the "
                     "tuple of data results.");
              if (failed(ConvertInfeedtLayout(inst, layout, data_tuple_shape)))
                return failure();
            }
            propagate_layout_to_data_tuple = false;
          }
        }
      }
    }
    return success();
  }

  auto& value_map = *value_lowering;
  ElementsAttr const_attr;

  if (auto call_op = dyn_cast<mlir::func::CallOp>(inst)) {
    return LowerFunctionCall(call_op, builder, &value_map);
  }

  if (auto op = dyn_cast<mlir::tensor::CastOp>(inst)) {
    Value operand = op.getOperand();
    auto ty = operand.getType().dyn_cast<ShapedType>();
    // If this was a cast from a static or bounded tensor, then it is a no-op
    // for export to HLO and we can use the operand.
    if (!ty || !IsBoundedOrStatic(ty)) {
      inst->emitOpError()
          << "requires static or bounded operand for HLO translation";
      return failure();
    }

    xla::XlaOp xla_operand;
    if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
      return failure();
    value_map[op.getResult()] = xla_operand;
    if (failed(propagate_layouts(inst, xla_operand))) {
      return failure();
    }
    return success();
  }

  if (matchPattern(inst, m_Constant(&const_attr))) {
    if (!inst->getResult(0).getType().isa<ShapedType>()) {
      return inst->emitError(
          "expected shaped type during constant mhlo -> hlo translation");
    }

    auto literal_or =
        CreateArrayLiteralFromAttr(const_attr, ExtractXlaShape(inst).layout());
    if (!literal_or.ok())
      return inst->emitError(literal_or.status().ToString());
    auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
    value_map[inst->getResult(0)] = constant;

    return success();
  }

  if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) {
    // Construct the return value for the function. If there is a single value
    // returned, then return it directly; otherwise create a tuple and return.
    unsigned num_return_values = inst->getNumOperands();
    const bool has_ret_shardings =
        !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
    if ((return_tuple_ && is_entry_function) || num_return_values != 1) {
      std::vector<xla::XlaOp> returns(num_return_values);
      for (OpOperand& ret : inst->getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        xla::XlaOp operand;
        if (failed(GetXlaOp(ret.get(), value_map, &operand, inst)))
          return failure();

        returns[index] = operand;
        if (!is_entry_function || !has_ret_shardings) continue;

        xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
        StatusOr<xla::XlaOp> reshape =
            ReshapeWithCorrectRepresentationAndSharding(
                builder, returns[index], return_shape,
                options_.layout_preference_fn, options_.shape_representation_fn,
                ret_shardings[index], /*fast_mem=*/false);
        if (!reshape.ok())
          return inst->emitError() << reshape.status().error_message();

        returns[index] = reshape.ValueOrDie();
      }

      if (has_ret_shardings) {
        xla::OpSharding sharding;
        sharding.set_type(xla::OpSharding::TUPLE);
        for (auto& ret_sharding : ret_shardings)
          *sharding.add_tuple_shardings() = *ret_sharding;

        builder->SetSharding(sharding);
      }

      *return_value = xla::Tuple(builder, returns);
      builder->ClearSharding();
    } else if (num_return_values == 1) {
      xla::XlaOp operand;
      if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst)))
        return failure();

      if (has_ret_shardings) {
        auto tuple = Tuple(builder, {operand});
        builder->SetSharding(*ret_shardings[0]);
        *return_value = GetTupleElement(tuple, 0);
        builder->ClearSharding();
      } else {
        *return_value = operand;
      }
    }

    return success();
  }

  inst->emitOpError() << "can't be translated to XLA HLO";
  return failure();
}

LogicalResult ConvertToHloModule::LowerFunctionCall(
    mlir::func::CallOp call_op, xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering) {
  auto& value_map = *value_lowering;
  mlir::func::FuncOp callee =
      module_.lookupSymbol<mlir::func::FuncOp>(call_op.getCallee());
  if (failed(RunOnFunction(callee))) return failure();
  std::vector<xla::XlaOp> operands;
  for (auto operand : call_op.getOperands()) {
    xla::XlaOp xla_operand;
    if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op)))
      return failure();
    operands.push_back(xla_operand);
  }
  // Each call to xla::Call inserts a copy of the callee computation into the
  // HLO, so each call site has a unique callee in the exported HLO. HLO does
  // not syntactically require all calls to have unique callees, but the call
  // graph is eventually "flattened" before lowering to make that true, because
  // buffer assignment needs this invariant.
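  // For example (hypothetical functions): two func.call ops targeting @f
  // each produce an xla::Call here, and after flattening each call refers
  // to its own copy of f's computation.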
  xla::XlaOp call_result =
      xla::Call(builder, lowered_computation_[callee], operands);
  // Use GetTupleElement for multiple outputs
  unsigned num_results = call_op.getNumResults();
  if (num_results > 1) {
    for (unsigned i = 0; i != num_results; ++i) {
      value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i);
    }
  } else if (num_results == 1) {
    value_map[call_op.getResult(0)] = call_result;
  }
  return success();
}

LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) {
  if (lowered_computation_.count(f)) return success();
  if (!llvm::hasSingleElement(f)) {
    return f.emitError("only single block Function supported");
  }

  // Create a sub-builder if this is not the main function.
  std::unique_ptr<xla::XlaBuilder> builder_up;
  bool entry_function = f.getName() == "main";
  if (!entry_function)
    builder_up = module_builder_.CreateSubBuilder(f.getName().str());
  auto& builder = entry_function ? module_builder_ : *builder_up;

  xla::XlaComputation computation;
  std::vector<bool> entry_args_same_across_replicas;
  llvm::SmallVector<std::optional<xla::OpSharding>, 4> arg_shardings;
  llvm::SmallVector<std::optional<xla::OpSharding>, 4> ret_shardings;
  if (entry_function) {
    bool any_arg_replicated = false;
    entry_args_same_across_replicas.reserve(f.getNumArguments());
    for (int64_t i = 0; i < f.getNumArguments(); ++i) {
      auto attr = f.getArgAttrOfType<mlir::UnitAttr>(i, kReplicationAttr);
      entry_args_same_across_replicas.push_back(attr != nullptr);
      any_arg_replicated |= entry_args_same_across_replicas.back();
      // Pass the alias info to the builder so that it will build the alias info
      // into the resulting HloModule.
      auto aliasing_output =
          f.getArgAttrOfType<mlir::IntegerAttr>(i, "tf.aliasing_output");
      if (!aliasing_output) continue;
      xla::ShapeIndex output_index;
      if ((return_tuple_ && entry_function) || f.getNumResults() != 1) {
        output_index = {aliasing_output.getInt()};
      } else {
        if (aliasing_output.getInt() != 0) {
          return f.emitError(
              "Aliasing output must be 0 if only one output exists");
        }
        output_index = {};
      }
      if (use_tuple_args_) {
        builder.SetUpAlias(output_index, /*param_number=*/0,
                           /*param_index=*/{i});
      } else {
        builder.SetUpAlias(output_index, /*param_number=*/i,
                           /*param_index=*/{});
      }
    }
    // Do not populate this field when nothing is replicated, since empty field
    // means no replication. This avoids the need for unrelated tests to handle
    // this field.
    if (!any_arg_replicated) entry_args_same_across_replicas.clear();

    ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
  }
  if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function,
                                       false, entry_args_same_across_replicas,
                                       arg_shardings, ret_shardings,
                                       &computation))) {
    return failure();
  }
  lowered_computation_[f] = std::move(computation);
  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
    Block* block, const std::vector<bool>& entry_args_same_across_replicas,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
    std::vector<bool>* leaf_replication) {
  arg_shapes->reserve(block->getNumArguments());
  leaf_replication->reserve(block->getNumArguments());
  for (BlockArgument& arg : block->getArguments()) {
    arg_shapes->push_back(xla::TypeToShape(arg.getType()));
    xla::Shape& arg_shape = arg_shapes->back();
    auto layout_preference_status =
        options_.layout_preference_fn ? options_.layout_preference_fn(arg_shape)
                                      : XlaLayoutPreference::kNoPreference;
    if (!layout_preference_status.ok())
      return block->getParentOp()->emitError()
             << layout_preference_status.status().error_message();

    auto arg_shape_status = options_.shape_representation_fn
                                ? options_.shape_representation_fn(
                                      arg_shape, /*use_fast_memory=*/false,
                                      layout_preference_status.ValueOrDie())
                                : arg_shape;
    if (!arg_shape_status.ok())
      return block->getParentOp()->emitError()
             << arg_shape_status.status().error_message();

    arg_shape = std::move(arg_shape_status.ValueOrDie());

    if (entry_args_same_across_replicas.empty()) continue;
    for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
      leaf_replication->push_back(
          entry_args_same_across_replicas[arg.getArgNumber()]);
  }

  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShardings(
    Block* block, xla::XlaBuilder* builder,
    llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
  if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::TUPLE);
    for (const auto& arg_sharding : llvm::enumerate(arg_shardings)) {
      auto hlo_sharding = xla::HloSharding::FromProto(*arg_sharding.value());
      if (!hlo_sharding.ok())
        return block->getParentOp()->emitError()
               << hlo_sharding.status().error_message();

      auto status = RewriteLayoutWithShardedShape(
          hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
          options_.layout_preference_fn, options_.shape_representation_fn,
          &(*arg_shapes)[arg_sharding.index()]);
      if (!status.ok())
        return block->getParentOp()->emitError() << status.error_message();

      *sharding.add_tuple_shardings() = *arg_sharding.value();
    }

    builder->SetSharding(sharding);
  }

  return success();
}

LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
    Block* block, xla::XlaBuilder* builder, bool is_entry_function,
    bool ensure_single_arg,
    const std::vector<bool>& entry_args_same_across_replicas,
    llvm::ArrayRef<std::optional<xla::OpSharding>> arg_shardings,
    llvm::ArrayRef<std::optional<xla::OpSharding>> ret_shardings,
    xla::XlaComputation* result,
    llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands) {
  // Mapping from the Value to lowered XlaOp.
  ValueLoweringMap lowering;

  // If using tuples as input, then there is only one input parameter that is a
  // tuple.
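  // For example (hypothetical entry function): main(%a : tensor<2xf32>,
  // %b : tensor<i32>) with use_tuple_args_ set becomes a single parameter
  //   arg_tuple : (f32[2], s32[])
  // and %a, %b are bound to get-tuple-element(arg_tuple, 0) / (arg_tuple, 1).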
2398   if (is_entry_function && use_tuple_args_) {
2399     llvm::SmallVector<xla::Shape, 4> arg_shapes;
2400     std::vector<bool> leaf_replication;
2401     if (failed(SetEntryTupleShapesAndLeafReplication(
2402             block, entry_args_same_across_replicas, &arg_shapes,
2403             &leaf_replication)))
2404       return failure();
2405 
2406     if (failed(
2407             SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
2408       return failure();
2409 
2410     xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
2411     auto tuple =
2412         xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
2413     builder->ClearSharding();
2414 
2415     bool set_tuple_element_sharding =
2416         !arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings);
2417     for (BlockArgument& arg : block->getArguments()) {
2418       if (set_tuple_element_sharding)
2419         builder->SetSharding(*arg_shardings[arg.getArgNumber()]);
2420       lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
2421     }
2422     builder->ClearSharding();
  } else {
    if (ensure_single_arg) {
      // Applicable to mhlo.IfOp, mhlo.CaseOp, and mhlo.WhileOp.
      llvm::SmallVector<xla::Shape, 4> arg_shapes;

      auto args_size = block->getNumArguments();
      if (implicit_operands) args_size = implicit_operands->size();

      arg_shapes.reserve(args_size);
      if (implicit_operands) {
        for (auto implicit_operand : *implicit_operands)
          arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType()));
      } else {
        for (BlockArgument& arg : block->getArguments())
          arg_shapes.push_back(xla::TypeToShape(arg.getType()));
      }

      if (args_size > 1) {
        auto tuple = xla::Parameter(builder, 0,
                                    xla::ShapeUtil::MakeTupleShape(arg_shapes),
                                    "arg_tuple");

        if (implicit_operands) {
          int arg_index = 0;
          for (auto implicit_operand : *implicit_operands)
            lowering[implicit_operand] =
                xla::GetTupleElement(tuple, arg_index++);
        } else {
          for (BlockArgument& arg : block->getArguments())
            lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
        }
      } else if (args_size == 1) {
        if (implicit_operands) {
          lowering[(*implicit_operands)[0]] =
              xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
        } else {
          lowering[block->getArgument(0)] =
              xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
        }
      } else {
        // Applicable only to IfOp and CaseOp: having no implicit operands
        // implies no XLA parameters, so create an empty tuple as the
        // block parameter.
        xla::Parameter(builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes),
                       "arg_empty_tuple");
      }
    } else {
      for (BlockArgument& arg : block->getArguments()) {
        auto num = arg.getArgNumber();
        xla::Shape shape = xla::TypeToShape(arg.getType());
        if (!arg_shardings.empty() && arg_shardings[num]) {
          builder->SetSharding(*arg_shardings[num]);
        }
        if (entry_args_same_across_replicas.empty()) {
          lowering[arg] =
              xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
        } else {
          // One replication flag per leaf buffer of the parameter's shape,
          // via std::vector's (count, value) constructor.
          lowering[arg] = xla::Parameter(
              builder, num, shape, absl::StrCat("Arg_", num),
              std::vector<bool>(xla::ShapeUtil::GetLeafCount(shape),
                                entry_args_same_across_replicas[num]));
        }
        builder->ClearSharding();
      }
    }
  }

  xla::XlaOp return_value;
  for (auto& inst : *block)
    if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
                     &lowering, &return_value)))
      return failure();

  // Build the XlaComputation and check for failures.
  auto computation_or =
      return_value.valid() ? builder->Build(return_value) : builder->Build();
  if (!computation_or.ok()) {
    block->back().emitError(
        llvm::Twine(computation_or.status().error_message()));
    return failure();
  }
  *result = std::move(computation_or.ValueOrDie());
  return success();
}

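// Lowers the single block of `region` into `func`, using a fresh sub-builder
// named "region_<id>" so every lowered region gets a unique computation name.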
LogicalResult ConvertToHloModule::LowerRegionAsComputation(
    mlir::Region* region, xla::XlaComputation* func,
    llvm::Optional<llvm::ArrayRef<mlir::Value>> implicit_operands,
    bool ensure_single_arg) {
  std::unique_ptr<xla::XlaBuilder> builder =
      module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
  return LowerBasicBlockAsFunction(&region->front(), builder.get(),
                                   /*is_entry_function=*/false,
                                   /*ensure_single_arg=*/ensure_single_arg,
                                   /*entry_args_same_across_replicas=*/{},
                                   /*arg_shardings=*/{}, /*ret_shardings=*/{},
                                   func, implicit_operands);
}

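// Adds one entry to `binding` recording that dimension `shape_index` of
// argument `arg_index` is dynamic, with its runtime size carried by the
// argument at `padding_arg_index`. With tuple args, both arguments live
// inside parameter 0 and are addressed by tuple index; otherwise each
// argument is its own parameter. A sketch (not taken from a call site in
// this file):
//   AddDynamicParameterBindingEntry(&binding, /*arg_index=*/1,
//                                   /*shape_index=*/0,
//                                   /*padding_arg_index=*/2,
//                                   /*use_tuple_args=*/false);
// records that parameter 2 holds the dynamic size of dimension 0 of
// parameter 1.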
void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
                                     int arg_index, int32_t shape_index,
                                     int32_t padding_arg_index,
                                     bool use_tuple_args) {
  auto* entry = binding->add_entries();
  entry->set_target_param_dim_num(shape_index);
  if (use_tuple_args) {
    entry->set_target_param_num(0);
    entry->add_target_param_index(arg_index);
    entry->set_dynamic_param_num(0);
    entry->add_dynamic_param_index(padding_arg_index);
  } else {
    entry->set_target_param_num(arg_index);
    entry->set_dynamic_param_num(padding_arg_index);
  }
}

// Runs the PrepareForExport pass on the ModuleOp.
Status PrepareForExport(mlir::ModuleOp module) {
  // Prepare for export to XLA HLO.
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::func::FuncOp>(mhlo::CreatePrepareForExport());
  if (failed(pm.run(module)))
    return tensorflow::errors::Internal("Unable to optimize for XLA export");
  return ::tensorflow::OkStatus();
}

}  // namespace

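// Converts `region` to an xla::XlaComputation. Only the region itself is
// lowered, so the converter is backed by a throwaway "main" builder and a
// null ModuleOp.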
Status ConvertRegionToComputation(mlir::Region* region,
                                  xla::XlaComputation* func,
                                  MlirToHloConversionOptions options) {
  mlir::ModuleOp module;
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, /*use_tuple_args=*/true,
                               /*return_tuple=*/true, options);
  if (failed(converter.LowerRegionAsComputation(region, func)))
    return tensorflow::errors::Internal(
        "failed to convert region to computation");
  return ::tensorflow::OkStatus();
}

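// Converts `module` to an xla::HloProto: runs the PrepareForExport pipeline,
// lowers every function, and names the resulting HLO module after the MLIR
// module (falling back to "main" when the module is unnamed).
//
// A minimal usage sketch (assuming `module` already holds valid MHLO and
// default-constructed options are acceptable):
//   xla::HloProto proto;
//   TF_RETURN_IF_ERROR(ConvertMlirHloToHlo(module, &proto,
//                                          /*use_tuple_args=*/false,
//                                          /*return_tuple=*/false,
//                                          MlirToHloConversionOptions()));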
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
                           bool use_tuple_args, bool return_tuple,
                           MlirToHloConversionOptions options) {
  TF_RETURN_IF_ERROR(PrepareForExport(module));
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  xla::XlaBuilder module_builder("main");
  ConvertToHloModule converter(module, module_builder, use_tuple_args,
                               return_tuple, options);
  if (failed(converter.Run())) return diag_handler.ConsumeStatus();
  auto hlo_module = converter.ConsumeMainProto();
  StringRef module_name = module.getName() ? *module.getName() : "main";
  hlo_module.set_name(module_name.str());
  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
  return ::tensorflow::OkStatus();
}

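// Lowers the ops of `block` directly into `builder`, binding the block's
// arguments to the caller-supplied `xla_params` instead of creating new XLA
// parameters. Operands of the terminator are reported through `returns`; no
// xla::XlaComputation is built, so the caller retains control of the builder.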
Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           MlirToHloConversionOptions options) {
  auto module = block.getParentOp()->getParentOfType<mlir::ModuleOp>();
  TF_RETURN_IF_ERROR(PrepareForExport(module));
  ConvertToHloModule converter(module, builder,
                               /*use_tuple_args=*/false, /*return_tuple=*/false,
                               options);

  ConvertToHloModule::ValueLoweringMap lowering;
  // xla_params should include only the non-constant parameters that the
  // block's arguments correspond to.
  if (xla_params.size() != block.getArguments().size())
    return tensorflow::errors::Internal("xla_params size (", xla_params.size(),
                                        ") != block arguments size (",
                                        block.getArguments().size(), ")");
  for (BlockArgument& arg : block.getArguments()) {
    auto num = arg.getArgNumber();
    lowering[arg] = xla_params[num];
  }

  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  for (auto& inst : block) {
    if (isa<mhlo::ReturnOp, mlir::func::ReturnOp>(inst)) {
      // Terminator: hand the lowered operands back to the caller.
      returns.resize(inst.getNumOperands());
      for (OpOperand& ret : inst.getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        xla::XlaOp operand;
        if (failed(GetXlaOp(ret.get(), lowering, &operand, &inst)))
          return diag_handler.ConsumeStatus();
        returns[index] = operand;
      }
    } else {
      xla::XlaOp return_value;
      if (failed(converter.Lower(&inst, /*is_entry_function=*/true,
                                 /*ret_shardings=*/{}, &builder, &lowering,
                                 &return_value)))
        return diag_handler.ConsumeStatus();
    }
  }

  return ::tensorflow::OkStatus();
}

}  // namespace mlir