xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering TensorFlow dialect to XLA dialect.
17 
18 #include <cctype>
19 #include <cstddef>
20 #include <cstdint>
21 #include <iterator>
22 #include <limits>
23 #include <numeric>
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/Optional.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/Sequence.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/ErrorHandling.h"
32 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
33 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
34 #include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
35 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
36 #include "mlir/Dialect/Traits.h"  // from @llvm-project
37 #include "mlir/IR/Attributes.h"  // from @llvm-project
38 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
39 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
40 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
41 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
42 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
43 #include "mlir/IR/Matchers.h"  // from @llvm-project
44 #include "mlir/IR/Operation.h"  // from @llvm-project
45 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
46 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
47 #include "mlir/IR/Types.h"  // from @llvm-project
48 #include "mlir/Pass/Pass.h"  // from @llvm-project
49 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
50 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
51 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
52 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
53 #include "tensorflow/compiler/mlir/xla/transforms/utils.h"
54 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
55 #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
56 #include "tensorflow/compiler/xla/client/padding.h"
57 #include "tensorflow/compiler/xla/client/sharding_builder.h"
58 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
59 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
60 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h"
61 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/hlo_utils.h"
62 #include "tensorflow/compiler/xla/xla_data.pb.h"
63 #include "tensorflow/core/framework/kernel_shape_util.h"
64 #include "tensorflow/core/framework/rng_alg.h"
65 #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
66 #include "tensorflow/core/platform/bfloat16.h"
67 #include "tensorflow/core/util/padding.h"
68 #include "tensorflow/core/util/tensor_format.h"
69 
70 namespace mlir {
71 namespace mhlo {
72 namespace {
73 
74 constexpr char kShardingAttr[] = "mhlo.sharding";
75 
76 /// Returns the feature dimension for the given format and input type.
77 static size_t GetFeatureDimension(tensorflow::TensorFormat format,
78                                   RankedTensorType input_ty) {
79   return GetTensorFeatureDimIndex(input_ty.getRank(), format);
80 }
81 
82 // Gets all integer values from the given attribute and pushes them to `values`.
83 void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
84   auto array_attr = attr.cast<ArrayAttr>();
85   values->reserve(array_attr.getValue().size());
86   for (Attribute val : array_attr.getValue())
87     values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
88 }
89 
90 // Returns a 1D 32-bit dense elements attribute with the given values.
91 static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
92                                                Builder *builder) {
93   RankedTensorType ty = RankedTensorType::get(
94       {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
95   return DenseIntElementsAttr::get(ty, values);
96 }
97 
98 // Returns a 1-d i64 elements attribute populated with numbers from `start`
99 // to `end` (exclusive).
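// For example, GetI64ElementsAttrForSeq(2, 5, &b) yields [2, 3, 4] as a
// tensor<3xi64> attribute.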
100 static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
101                                                      Builder *builder) {
102   int size = end - start;
103 
104   SmallVector<int64_t, 4> vals;
105   vals.resize(size);
106   std::iota(vals.begin(), vals.end(), start);
107 
108   TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
109   return DenseIntElementsAttr::get(ty, vals);
110 }
111 
112 // Returns a 1-d i64 elements attribute populated with `val` repeated `size`
113 // times.
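// For example, GetI64ElementsAttrForValue(3, 7, &b) yields [7, 7, 7] as a
// tensor<3xi64> attribute.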
114 static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val,
115                                                        Builder *builder) {
116   TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
117   return DenseIntElementsAttr::get(ty, val);
118 }
119 
120 // Returns the corresponding type that should be used for performing sum
121 // accumulation over the given input type.
122 Type GetSumAccumulationType(Type input_type) {
123   MLIRContext *ctx = input_type.getContext();
124   if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
125   if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
126     return IntegerType::get(ctx, 32);
127   return input_type;
128 }
129 
130 // Returns the axis in HLO format, given either a TF elements attr with
131 // exactly one element or an IntegerAttr containing the axis in TensorFlow
132 // format. Unlike HLO, the TensorFlow format supports negative indexing.
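// For example, a TF axis of -1 on a rank-4 tensor maps to HLO axis 3.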
133 static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank,
134                                         Builder *b) {
135   IntegerAttr intAttr = attr.dyn_cast_or_null<IntegerAttr>();
136   if (auto elementAttr = attr.dyn_cast_or_null<ElementsAttr>()) {
137     SmallVector<uint64_t, 1> index(elementAttr.getType().getRank(), 0);
138     intAttr = elementAttr.getValues<IntegerAttr>()[index];
139   }
140 
141   assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis");
142 
143   int64_t axis = intAttr.getInt();
144   if (axis < 0) {
145     axis += rank;
146   }
147   return b->getI64IntegerAttr(axis);
148 }
149 
150 // If `value` is an IntegerAttr, returns the integer value for the HLO axis
151 // corresponding to the tensorflow axis. In particular, the tensorflow axis can
152 // be negative, in which case, the corresponding HLO axis is
153 // (axis + rank-of-the-tensor).
154 static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
155                                                            int64_t rank) {
156   DenseIntElementsAttr attrs;
157   if (!matchPattern(value, m_Constant(&attrs)) ||
158       attrs.getType().getRank() != 0) {
159     return llvm::None;
160   }
161   int64_t axis = attrs.getValues<IntegerAttr>()[0].getInt();
162   return axis < 0 ? axis + rank : axis;
163 }
164 
165 /// Returns a `ConvertOp` that casts the elements to an i64 type while
166 /// retaining the shape of the input value.
167 static ConvertOp CastValueToI64(Location loc, Value value,
168                                 PatternRewriter *rewriter) {
169   return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
170 }
171 
172 // Creates an unpack op along the 0th dimension of the tensor. The `value` input
173 // must be a ranked tensor.
174 static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
175                                              PatternRewriter *rewriter) {
176   auto indices_type = value.getType().cast<RankedTensorType>();
177   int num_outputs = indices_type.getShape().front();
178   SmallVector<Type, 2> unpacked_indices_type(
179       num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
180   auto unpacked_indices = rewriter->create<TF::UnpackOp>(
181       loc, unpacked_indices_type, value,
182       IntegerAttr::get(rewriter->getIntegerType(64), 0));
183   return unpacked_indices;
184 }
185 
186 // Returns the size of the dimension at the specified index if the type is a
187 // ranked tensor. Otherwise, returns -1.
188 //
189 // Aborts if the type is ranked but doesn't have the dimension.
190 int64_t GetDimSize(Type ty, int64_t index) {
191   RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
192   if (!ranked_ty) return -1;
193 
194   return ranked_ty.getDimSize(index);
195 }
196 
197 template <typename T, int num_dims>
198 tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
199   return tensorflow::TensorShape(
200       llvm::SmallVector<int64_t, num_dims>(sizes.begin(), sizes.end()));
201 }
202 
203 template <typename T, int num_dims>
204 tensorflow::TensorShape ToTensorShape(
205     llvm::iterator_range<DenseElementsAttr::ElementIterator<T>> sizes) {
206   return tensorflow::TensorShape(
207       llvm::SmallVector<int64_t, num_dims>(sizes.begin(), sizes.end()));
208 }
209 
210 // Returns a limit scalar const op for the given type.
211 // Requires FloatType or IntegerType
212 static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc,
213                                             hlo::ScalarLimit limit,
214                                             OpBuilder *builder) {
215   return builder->create<ConstantOp>(loc, hlo::getScalarLimitOfType(ty, limit));
216 }
217 
218 // Creates an mhlo::SliceOp where the major dimensions have full size, and
219 // the minor dimensions have the provided offsets and sizes.
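// For example, for a tensor<2x3x4> value, minor_starts = {1} and
// minor_limits = {3} keep the two major dims whole and slice the last dim to
// [1, 3).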
220 static Value SliceInMinorDims(Location loc, Value v,
221                               ArrayRef<int64_t> minor_starts,
222                               ArrayRef<int64_t> minor_limits,
223                               OpBuilder *builder) {
224   auto type = v.getType().cast<RankedTensorType>();
225   llvm::SmallVector<int64_t, 4> slice_starts(type.getRank(), 0);
226   int64_t major_dims = type.getRank() - minor_starts.size();
227   std::copy(minor_starts.begin(), minor_starts.end(),
228             slice_starts.begin() + major_dims);
229   auto slice_limits = llvm::to_vector<4>(type.getShape());
230   std::copy(minor_limits.begin(), minor_limits.end(),
231             slice_limits.begin() + major_dims);
232   llvm::SmallVector<int64_t, 4> slice_strides(type.getRank(), 1);
233   return builder->create<SliceOp>(loc, v,
234                                   GetI64ElementsAttr(slice_starts, builder),
235                                   GetI64ElementsAttr(slice_limits, builder),
236                                   GetI64ElementsAttr(slice_strides, builder));
237 }
238 
239 // Creates a vector of index values:
240 //  [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]]
241 // with length `rank`.
242 static llvm::SmallVector<Value, 4> CreateFullIndexVectorFromMinorIndices(
243     Location loc, ArrayRef<Value> minor_indices, int64_t rank,
244     OpBuilder *builder) {
245   auto zero =
246       GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()),
247                            loc, 0, builder)
248           .output();
249   llvm::SmallVector<Value, 4> indices(rank, zero);
250   std::copy(minor_indices.begin(), minor_indices.end(),
251             indices.begin() + (rank - minor_indices.size()));
252   return indices;
253 }
254 
255 // Creates an mhlo::DynamicSliceOp where the major dimensions have full size,
256 // and the minor dimensions have the provided offsets and sizes.
257 static Value DynamicSliceInMinorDims(Location loc, Value v,
258                                      ArrayRef<Value> minor_starts,
259                                      ArrayRef<int64_t> minor_sizes,
260                                      OpBuilder *builder) {
261   if (minor_starts.empty()) return v;
262   auto type = v.getType().cast<RankedTensorType>();
263   auto slice_starts = CreateFullIndexVectorFromMinorIndices(
264       loc, minor_starts, type.getRank(), builder);
265   int64_t major_dims = type.getRank() - minor_starts.size();
266   auto slice_sizes = llvm::to_vector<4>(type.getShape());
267   std::copy(minor_sizes.begin(), minor_sizes.end(),
268             slice_sizes.begin() + major_dims);
269   return builder->create<mhlo::DynamicSliceOp>(
270       loc, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder));
271 }
272 
273 // Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
274 // offsets, and the minor dimensions have the provided offsets.
275 static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
276                                            ArrayRef<Value> minor_starts,
277                                            OpBuilder *builder) {
278   if (minor_starts.empty()) return v;
279   auto type = v.getType().cast<RankedTensorType>();
280   auto dus_starts = CreateFullIndexVectorFromMinorIndices(
281       loc, minor_starts, type.getRank(), builder);
282   return builder->create<DynamicUpdateSliceOp>(loc, type, v, update,
283                                                llvm::makeArrayRef(dus_starts));
284 }
285 
286 // Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
287 // offsets, and the minor dimensions have the provided static offsets.
288 static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
289                                     ArrayRef<int64_t> minor_starts,
290                                     OpBuilder *builder) {
291   llvm::SmallVector<Value, 4> dus_starts(minor_starts.size());
292   for (uint64_t i = 0; i < minor_starts.size(); ++i) {
293     dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
294                                          minor_starts[i], builder);
295   }
296   return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
297 }
298 
299 // Deprecated: This is maintained to aid in porting old code that is not yet
300 // dynamic shape aware and uses broadcasting modes that CHLO does not support.
301 // Gets the resulting type from a broadcast between two types for statically
302 // shaped types. This is to be used for legacy lowerings that both use non
303 // left-padded broadcasting and static shapes. Its use should not be permitted
304 // in new code.
305 // May return nullptr on invalid static broadcast dimensions.
306 // ABSL_DEPRECATED()
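// For example, broadcasting shapes [2, 1] and [3] with
// broadcast_dimensions = [1] yields [2, 3].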
307 static RankedTensorType GetStaticBroadcastType(
308     RankedTensorType x, RankedTensorType y,
309     DenseIntElementsAttr broadcast_dimensions_attr) {
310   auto element_type = x.getElementType();
311   auto shape_x = x.getShape();
312   auto shape_y = y.getShape();
313 
314   if (shape_x.size() == shape_y.size()) {
315     llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
316     for (int i = 0; i < shape_x.size(); i++) {
317       auto x_val = shape_x[i];
318       auto y_val = shape_y[i];
319       out_shape[i] = std::max(x_val, y_val);
320     }
321     return RankedTensorType::get(out_shape, element_type);
322   }
323 
324   auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
325   auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
326 
327   llvm::SmallVector<int64_t, 4> broadcast_dimensions;
328   // Explicit broadcast dimensions.
329   for (const APInt &int_value : broadcast_dimensions_attr) {
330     broadcast_dimensions.push_back(int_value.getSExtValue());
331   }
332   if (broadcast_dimensions.size() != shape_small.size()) {
333     return nullptr;
334   }
335   llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
336                                           shape_large.end());
337 
338   // Update according to the broadcast dimensions.
339   for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
340     auto old_value = out_shape[index_pair.value()];
341     auto new_value = shape_small[index_pair.index()];
342     out_shape[index_pair.value()] = std::max(old_value, new_value);
343   }
344   return RankedTensorType::get(out_shape, element_type);
345 }
346 
347 // Deprecated: This is maintained to aid in porting old code that is not yet
348 // dynamic shape aware and uses broadcasting modes that CHLO does not support.
349 // Applies static binary broadcasting to a binary elementwise op.
350 // This is a legacy helper to provide general broadcasting support in legacy,
351 // static shaped code that relies on non-left-padded broadcasting semantics.
352 template <typename BinaryOp>
353 static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
354                                    DenseIntElementsAttr broadcast_dims,
355                                    OpBuilder &builder) {
356   auto x_type = x.getType().cast<RankedTensorType>();
357   auto y_type = y.getType().cast<RankedTensorType>();
358   auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
359   if (!result_type) {
360     emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
361                    << " with broadcast_dims = " << broadcast_dims;
362     return nullptr;
363   }
364   auto larger_broadcast_dims =
365       GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
366   if (x_type.getRank() < y_type.getRank()) {
367     if (x_type != result_type) {
368       x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
369     }
370     if (y_type != result_type) {
371       y = builder.create<BroadcastInDimOp>(loc, result_type, y,
372                                            larger_broadcast_dims);
373     }
374   } else {
375     if (x_type != result_type) {
376       x = builder.create<BroadcastInDimOp>(loc, result_type, x,
377                                            larger_broadcast_dims);
378     }
379     if (y_type != result_type) {
380       y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
381     }
382   }
383   return builder.create<BinaryOp>(loc, x, y);
384 }
385 
386 // Gets a 1D tensor type suitable for expressing extents of the given tensor
387 // value type. If the value type is ranked, the result will be statically
388 // shaped. Otherwise, it will have a dynamic dimension.
389 static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
390   Builder b(value_type.getContext());
391   int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
392   return RankedTensorType::get({dim}, b.getIndexType());
393 }
394 
395 // Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
396 // value (broadcast_from) along that feature dimension. This is a shortcut
397 // for the cases where a 1D tensor must be broadcast along a specific feature
398 // dimension, which can vary based on data layout, etc.
399 //
400 // The extent of `broadcast_from` dim0 must be equal to the extent of the
401 // feature_dim of `broadcast_to`.
402 //
403 // Example:
404 //   [1x2x3x4], [2], 1 -> [1x2x3x4]
405 // TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
406 // consistency. Possibly also rename for clarity.
407 static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
408                                      Value broadcast_from, int64_t feature_dim,
409                                      OpBuilder &builder) {
410   auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
411   auto to_type = broadcast_to.getType().cast<RankedTensorType>();
412   auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
413   auto result_extents_type = GetExtentsTensorTypeFor(to_type);
414   auto result_extents = builder.create<shape::ToExtentTensorOp>(
415       loc, result_extents_type, result_shape);
416   return builder.create<DynamicBroadcastInDimOp>(
417       loc, to_type, broadcast_from, result_extents, broadcast_dims);
418 }
419 
420 // Broadcasts `input` to the shape of `broadcast_to` value following
421 // TF::BroadcastTo semantics.
422 //
423 // Requires that input is a ranked tensor.
424 //
425 // TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp
426 // supports unranked inputs in the lowering.
427 static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to,
428                                 OpBuilder &builder) {
429   auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
430   auto to_type = broadcast_to.getType().cast<TensorType>();
431   auto result_extents_type = GetExtentsTensorTypeFor(to_type);
432   auto result_extents = builder.create<shape::ToExtentTensorOp>(
433       loc, result_extents_type, result_shape);
434   int64_t rank = input.getType().cast<RankedTensorType>().getRank();
435   auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder);
436   return builder.create<DynamicBroadcastInDimOp>(
437       loc, to_type, input, result_extents, broadcast_dims);
438 }
439 
440 // Creates a batch dot using mhlo::DotGeneralOp.
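// For example, with num_batch_dims = 1 and no transposes, [B, M, K] x
// [B, K, N] contracts over K and produces [B, M, N].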
441 Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
442                bool transpose_rhs, int64_t num_batch_dims,
443                ArrayAttr precision_config, OpBuilder *builder) {
444   auto batch_dimensions =
445       llvm::to_vector<4>(llvm::seq<int64_t>(0, num_batch_dims));
446   auto lhs_contracting_dimensions = llvm::to_vector<1>(llvm::makeArrayRef(
447       {transpose_lhs ? num_batch_dims : num_batch_dims + 1}));
448   auto rhs_contracting_dimensions = llvm::to_vector<1>(llvm::makeArrayRef(
449       {transpose_rhs ? num_batch_dims + 1 : num_batch_dims}));
450   auto dimension_numbers = DotDimensionNumbersAttr::get(
451       builder->getContext(),
452       /*lhs_batching_dimensions=*/batch_dimensions,
453       /*rhs_batching_dimensions=*/batch_dimensions,
454       /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
455       /*rhs_contracting_dimensions=*/rhs_contracting_dimensions);
456   auto lhs_shape = lhs.getType().cast<RankedTensorType>().getShape();
457   auto rhs_shape = rhs.getType().cast<RankedTensorType>().getShape();
458   auto shape = llvm::to_vector<4>(lhs_shape);
459   shape[shape.size() - 2] =
460       transpose_lhs ? lhs_shape.back() : lhs_shape[lhs_shape.size() - 2];
461   shape[shape.size() - 1] =
462       transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back();
463   Type element_type = getElementTypeOrSelf(lhs.getType());
464   return builder->create<DotGeneralOp>(
465       loc, RankedTensorType::get(shape, element_type), lhs, rhs,
466       dimension_numbers, precision_config);
467 }
468 
469 // Builds a set of operations for applying reduction on the input value. A
470 // tf.Sum op is created and will be legalized to HLO ops automatically.
471 static Value ApplyReduction(Location loc, Value input,
472                             DenseIntElementsAttr reduce_dims,
473                             OpBuilder *builder) {
474   auto reduce_dims_op = builder->create<ConstantOp>(loc, reduce_dims);
475   return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
476                                     builder->getBoolAttr(false));
477 }
478 
479 // Creates an mhlo.rng op (uniform distribution) with `builder` to generate
480 // `num_elements` 32-bit integers in the range [`lower_limit`, `upper_limit`).
481 static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements,
482                                       int lower_limit, int upper_limit,
483                                       OpBuilder *builder) {
484   auto shape_tensor = builder->create<mhlo::ConstantOp>(
485       loc, GetI64ElementsAttr({num_elements}, builder));
486 
487   auto lower = builder->create<mhlo::ConstantOp>(
488       loc, builder->getI32IntegerAttr(lower_limit));
489   auto upper = builder->create<mhlo::ConstantOp>(
490       loc, builder->getI32IntegerAttr(upper_limit));
491 
492   return builder->create<mhlo::RngOp>(loc, lower, upper, shape_tensor,
493                                       ::mlir::mhlo::RngDistribution::UNIFORM);
494 }
495 
496 using WhileBodyFnType = llvm::function_ref<void(
497     Location loc, Value iteration, ArrayRef<Value> old_values,
498     SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;
499 
500 // Creates an mhlo.while op with `builder` to loop `num_iterations` times,
501 // each time calling the given `body_fn` on a set of values to generate a new
502 // set of values. Returns the final set of values via `final_values`. The
503 // initial set of values is passed in via `init_values`.
504 //
505 // This effectively does:
506 //
507 // ```c++
508 // SmallVector<Values, 4> old_values = init_values;
509 // SmallVector<Values, 4> new_values;
510 // for (int i = 0; i < num_iterations; ++i) {
511 //   body_fn(old_values, &new_values, ...);
512 //   old_values = new_values;
513 // }
514 // ```
515 //
516 // Under the hood an induction variable is prepended to values to control the
517 // number of iterations, but that is transparent to `body_fn`, which does not
518 // need to care about that.
519 static void CreateWhile32(Location loc, int num_iterations,
520                           WhileBodyFnType body_fn, ArrayRef<Value> init_values,
521                           SmallVectorImpl<Value> *final_values,
522                           OpBuilder *builder) {
523   int value_count = init_values.size() + 1;
524 
525   // Prepend a loop induction variable to the initial values.
526   SmallVector<Value, 2> init_values_with_loop_iv;
527   SmallVector<Type, 2> init_types_with_loop_iv;
528   init_values_with_loop_iv.reserve(value_count);
529   init_types_with_loop_iv.reserve(value_count);
530 
531   // The initial value for the loop induction variable is 0.
532   init_values_with_loop_iv.push_back(
533       builder->create<mhlo::ConstantOp>(loc, builder->getI32IntegerAttr(0)));
534   init_values_with_loop_iv.append(init_values.begin(), init_values.end());
535 
536   // Accumulate types of all the init values.
537   for (const auto &init_value_with_loop_iv : init_values_with_loop_iv)
538     init_types_with_loop_iv.push_back(init_value_with_loop_iv.getType());
539 
540   // Create the while op.
541   auto while_op = builder->create<mhlo::WhileOp>(loc, init_types_with_loop_iv,
542                                                  init_values_with_loop_iv);
543   auto ivs_count = init_types_with_loop_iv.size();
544 
545   {
546     OpBuilder::InsertionGuard guard(*builder);
547 
548     // Build up the only block in the condition region.
549     Region &condition = while_op.cond();
550     Block *block = builder->createBlock(&condition);
551     block->addArguments(init_types_with_loop_iv,
552                         SmallVector<Location>(ivs_count, loc));
553 
554     // Get the loop induction variable and compare it against the upper limit.
555     auto loop_iv = block->getArgument(0);
556     auto upper_limit = builder->create<mhlo::ConstantOp>(
557         loc, builder->getI32IntegerAttr(num_iterations));
558     Value compare = builder->create<mhlo::CompareOp>(loc, loop_iv, upper_limit,
559                                                      ComparisonDirection::LT);
560 
561     builder->create<mhlo::ReturnOp>(loc, compare);
562   }
563 
564   {
565     OpBuilder::InsertionGuard guard(*builder);
566 
567     // Build up the only block in the body region.
568     Region &body = while_op.body();
569     Block *block = builder->createBlock(&body);
570     block->addArguments(init_types_with_loop_iv,
571                         SmallVector<Location>(ivs_count, loc));
572 
573     SmallVector<Value, 4> new_values;  // Generated by this iteration
574     new_values.reserve(value_count);
575 
576     // Feed all values excluding the loop induction variable to body_fn.
577     body_fn(loc, block->getArgument(0),
578             ArrayRef<Value>(block->getArguments().begin() + 1,
579                             block->getArguments().end()),
580             &new_values, builder);
581 
582     // Increment the loop induction variable by one.
583     auto one =
584         builder->create<mhlo::ConstantOp>(loc, builder->getI32IntegerAttr(1));
585     auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
586     auto plus_one = builder->create<chlo::BroadcastAddOp>(
587         loc, block->getArgument(0), one, scalar_broadcast_dims);
588     // Prepend with the updated loop induction variable.
589     new_values.insert(new_values.begin(), plus_one);
590 
591     builder->create<mhlo::ReturnOp>(loc, new_values);
592   }
593 
594   // TODO(jpienaar): Support multi-operand while op.
595   final_values->reserve(init_values.size());
596   for (int i = 0, e = init_values.size(); i < e; ++i)
597     final_values->push_back(while_op.getResult(i + 1));
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // BatchNorm op utilities.
602 //===----------------------------------------------------------------------===//
603 
604 static IntegerAttr getFeatureDimensionAttr(Builder &b,
605                                            tensorflow::TensorFormat format,
606                                            Value input) {
607   return b.getI64IntegerAttr(
608       GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // FFT op utilities.
613 //===----------------------------------------------------------------------===//
614 
615 // Returns a 1D i64 elements attribute populated with the size of the
616 // innermost dim of the value.
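// For example, for a tensor<3x5xf32> type this yields [5] as a tensor<1xi64>
// attribute.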
617 static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type,
618                                                  Builder *builder) {
619   if (type.getRank() == 0) {
620     return builder->getI64TensorAttr({});
621   }
622   return builder->getI64TensorAttr(type.getShape().back());
623 }
624 
625 // Returns true if the innermost dim is static.
626 bool CheckInnerDimStatic(ShapedType type, Builder *builder) {
627   if (!type.hasRank()) {
628     return false;
629   }
630   return !type.isDynamicDim(type.getShape().size() - 1);
631 }
632 
633 //===----------------------------------------------------------------------===//
634 // MatMul op utilities.
635 //===----------------------------------------------------------------------===//
636 
637 // If the 'transpose' attribute is true, returns an ElementsAttr to transpose a
638 // 2D matrix. Otherwise, returns an ElementsAttr for the identity permutation.
639 static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
640   if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
641   return GetI64ElementsAttr({0, 1}, b);
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // MatrixBandPart op utilities.
646 //===----------------------------------------------------------------------===//
647 
648 // Gets the size of the dimension `dim_from_end` from the end of `input`.
649 // Requires that `input` is a tensor.
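// For example, for a tensor<2x3x4> input, dim_from_end = 0 returns 4 and
// dim_from_end = 1 returns 3.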
650 static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
651   // Note: the verifier enforces that `input` is a ranked tensor.
652   auto input_type = input.getType().cast<TensorType>();
653   auto input_shape = input_type.getShape();
654   int dim = (input_shape.size() - 1) - dim_from_end;
655   return input_shape[dim];
656 }
657 
658 // Gets a 2D tensor type with shape {dim_0, dim_1}, where `dim_0` and `dim_1`
659 // have the same size as the last two dimensions of `input` (the second-to-last
660 // dimension and last dimension, respectively). The element type of the
661 // output RankedTensorType will match the element type of `num_lower`.
662 // Requires that `input` is a tensor.
663 static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
664   // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last.
665   int dim_0 = GetDimensionSizeFromEnd(input, 1);
666   int dim_1 = GetDimensionSizeFromEnd(input, 0);
667   auto element_type = num_lower.getType().cast<TensorType>().getElementType();
668   return RankedTensorType::get({dim_0, dim_1}, element_type);
669 }
670 
671 // Creates an HLO ConvertOp, converting `input` to have the same element type as
672 // `elem_type_tensor`. Requires `elem_type_tensor` to be a tensor.
673 static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input,
674                              Value elem_type_tensor) {
675   auto element_type =
676       elem_type_tensor.getType().cast<TensorType>().getElementType();
677   return builder->create<mhlo::ConvertOp>(loc, input, element_type);
678 }
679 
680 //===----------------------------------------------------------------------===//
681 // Pad op utilities.
682 //===----------------------------------------------------------------------===//
683 
684 // Slices input attribute of rank two and returns the specified column.
685 //
686 // Always returns 64 bit integer attribute regardless of bitwidth of the input
687 // attribute.
688 static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
689     ElementsAttr input, int column) {
690   auto int_attr = input.cast<DenseIntElementsAttr>();
691   auto shaped_type = int_attr.getType();
692   auto shape = shaped_type.getShape();
693 
694   if (shape.size() != 2) return DenseIntElementsAttr();
695 
696   llvm::SmallVector<int64_t, 4> values;
697   values.reserve(shaped_type.getNumElements() / shape[1]);
698 
699   for (auto it : llvm::enumerate(int_attr.getValues<APInt>())) {
700     if (static_cast<int>(it.index() % shape[1]) == column) {
701       values.push_back(it.value().getSExtValue());
702     }
703   }
704 
705   auto element_type = IntegerType::get(input.getContext(), 64);
706   return DenseIntElementsAttr::get(
707       RankedTensorType::get({shape[0]}, element_type), values);
708 }
709 
710 // Returns the interior padding to use in the HLO Pad op, based on the padding
711 // attribute of the TensorFlow PadV2 op.
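// PadV2 itself carries no interior padding, so this is always a zero vector
// with one entry per padded dimension.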
712 static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
713   auto length = tf_padding.getType().getShape()[0];
714   auto element_type = IntegerType::get(tf_padding.getContext(), 64);
715   return DenseIntElementsAttr::get<int64_t>(
716       RankedTensorType::get({length}, element_type), 0);
717 }
718 
719 //===----------------------------------------------------------------------===//
720 // Binary op utilities.
721 //===----------------------------------------------------------------------===//
722 
723 // Returns whether the two values are guaranteed to be broadcastable to the
724 // same shape; size-1 tensors broadcast up to any rank. Dynamic dimensions may
725 // only be broadcast against a size-1 dimension or another dynamic dimension.
726 // Returns false if either value is unranked.
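// For example, tensor<?x3> and tensor<1x3> are compatible, while tensor<?x3>
// and tensor<4x3> are not (the dynamic dimension is not guaranteed to be 4).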
727 static bool AreBroadcastCompatible(Value x, Value y) {
728   auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
729   auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
730   if (!x_rankless || !y_rankless) {
731     return false;
732   }
733 
734   // Check that the shapes can be broadcasted.
735   auto shape_x = x_rankless.getShape();
736   auto shape_y = y_rankless.getShape();
737 
738   int rank_diff = shape_x.size() - shape_y.size();
739   int offset_x = rank_diff > 0 ? rank_diff : 0;
740   int offset_y = rank_diff < 0 ? -rank_diff : 0;
741   for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
742     int index_x = i + offset_x;
743     int index_y = i + offset_y;
744     if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
745         (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
746       return false;
747     }
748   }
749 
750   return true;
751 }
752 
753 // Returns a new TensorType with the same rank and dimensions as the input but
754 // with an updated element type.
755 static Type ChangeTensorElementType(Builder *b, Type tensor_type,
756                                     Type element_type) {
757   RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
758   if (ranked_type) {
759     return RankedTensorType::get(ranked_type.getShape(), element_type);
760   }
761 
762   return UnrankedTensorType::get(element_type);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // Softmax op utilities.
767 //===----------------------------------------------------------------------===//
768 
769 // Returns the type to use for accumulating the given type.
770 static Type GetAccumulationType(Type ty) {
771   // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
772   // repeated floating point additions.
773   return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // Softplus op utilities.
778 //===----------------------------------------------------------------------===//
779 
780 static DenseElementsAttr GetEpsilonValue(Type ty) {
781   auto element_ty = ty.cast<TensorType>().getElementType();
782   auto scalar_ty = RankedTensorType::get({}, element_ty);
783   if (element_ty.isF16()) {
784     uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
785         Eigen::NumTraits<Eigen::half>::epsilon());
786     auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
787     return DenseElementsAttr::get(scalar_ty, value);
788   } else if (element_ty.isBF16()) {
789     uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
790         Eigen::NumTraits<Eigen::bfloat16>::epsilon());
791     auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
792     return DenseElementsAttr::get(scalar_ty, value);
793   } else if (element_ty.isF32()) {
794     auto value = APFloat(std::numeric_limits<float>::epsilon());
795     return DenseElementsAttr::get(scalar_ty, value);
796   } else if (element_ty.isF64()) {
797     auto value = APFloat(std::numeric_limits<double>::epsilon());
798     return DenseElementsAttr::get(scalar_ty, value);
799   }
800   llvm_unreachable("unsupported element type for tf.SoftPlus");
801 }
802 
803 //===----------------------------------------------------------------------===//
804 // ArgMax/ArgMin op utilities.
805 //===----------------------------------------------------------------------===//
806 
807 static void BuildArgMinMaxReductionBody(Type input_element_type,
808                                         Type index_element_type,
809                                         ComparisonDirection direction,
810                                         Region *body, OpBuilder *builder) {
811   OpBuilder::InsertionGuard insertion_point_guard(*builder);
812 
813   Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
814   Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
815   Block *block = builder->createBlock(body);
816   Location loc = body->getLoc();
817   block->addArguments({input_type, index_type, input_type, index_type},
818                       SmallVector<Location, 4>(4, loc));
819 
820   Value lhs_val = block->getArgument(0);
821   Value lhs_index = block->getArgument(1);
822   Value rhs_val = block->getArgument(2);
823   Value rhs_index = block->getArgument(3);
824 
825   ImplicitLocOpBuilder b(loc, *builder);
826   Value compare_dt = b.create<CompareOp>(lhs_val, rhs_val, direction);
827   Value selected_input =
828       b.create<SelectOp>(input_type, compare_dt, lhs_val, rhs_val);
829 
830   Value compare_eq =
831       b.create<CompareOp>(lhs_val, rhs_val, ComparisonDirection::EQ);
832   Value min_index = b.create<MinOp>(lhs_index, rhs_index);
833   Value min_val_index =
834       b.create<SelectOp>(index_type, compare_dt, lhs_index, rhs_index);
835   Value selected_index =
836       b.create<SelectOp>(index_type, compare_eq, min_index, min_val_index);
837 
838   Value return_values[] = {selected_input, selected_index};
839   b.create<ReturnOp>(return_values);
840 }
841 
842 //===----------------------------------------------------------------------===//
843 // PartitionedCall op utilities.
844 //===----------------------------------------------------------------------===//
845 
846 // Verifies that the arguments to be passed into the function have the same
847 // types as the function parameter types.
848 static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args,
849                                 SymbolRefAttr func) {
850   auto module = op->getParentOfType<ModuleOp>();
851   auto function =
852       dyn_cast_or_null<func::FuncOp>(SymbolTable::lookupSymbolIn(module, func));
853   FunctionType function_ty = function.getFunctionType();
854 
855   for (auto arg_in : llvm::zip(args, function_ty.getInputs())) {
856     if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) {
857       // Argument type and input type mismatch.
858       return false;
859     }
860   }
861   return true;
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // Slice op utilities.
866 //===----------------------------------------------------------------------===//
867 
868 static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
869                                           DenseIntElementsAttr slice_sizes) {
870   auto input_ty = input.getType().dyn_cast<RankedTensorType>();
871   if (!input_ty) return false;
872   auto start_indices_ty = start_indices.getType().dyn_cast<RankedTensorType>();
873   if (!start_indices_ty) return false;
874 
875   int64_t input_rank = input_ty.getRank();
876   ArrayRef<int64_t> input_shape = input_ty.getShape();
877   DenseIntElementsAttr constant_start_indices;
878   bool is_constant_start =
879       matchPattern(start_indices, m_Constant(&constant_start_indices));
880 
881   for (int64_t i = 0; i < input_rank; ++i) {
882     int64_t input_size = input_shape[i];
883     int64_t slice_size = slice_sizes.getValues<IntegerAttr>()[i].getInt();
884     // A slice_size of -1 means "all elements from start_index to the end".
885     // In order to support these semantics, we need to know both the start index
886 // and the size of the input dimension.
887     if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false;
888   }
889   return true;
890 }
891 
892 // TF slice size can be -1, which represents all elements from start_index to
893 // the end. HLO slice size can't be -1. As such, we need to translate a TF
894 // slice size of -1 into an explicit HLO slice size.
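// For example, with an input dimension of size 10, a start index of 3 and a
// TF slice size of -1 translate to an HLO slice size of 7.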
895 static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
896     Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
897     Builder *builder) {
898   DenseIntElementsAttr constant_start_indices;
899   if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
900     return hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))
901         .cast<DenseIntElementsAttr>();
902   }
903 
904   auto input_ty = input.getType().dyn_cast<RankedTensorType>();
905   int64_t input_rank = input_ty.getRank();
906   ArrayRef<int64_t> input_shape = input_ty.getShape();
907   SmallVector<int64_t, 4> normalized_sizes;
908 
909   for (int64_t i = 0; i < input_rank; ++i) {
910     int64_t input_size = input_shape[i];
911     int64_t start_index =
912         constant_start_indices.getValues<IntegerAttr>()[i].getInt();
913     int64_t slice_size = slice_sizes.getValues<IntegerAttr>()[i].getInt();
914     normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
915                                                 : slice_size);
916   }
917 
918   return GetI64ElementsAttr(normalized_sizes, builder);
919 }
920 
921 //===----------------------------------------------------------------------===//
922 // XlaGather op utilities.
923 //===----------------------------------------------------------------------===//
924 
925 bool HasValidGatherDims(StringAttr attr) {
926   ::xla::GatherDimensionNumbers dims;
927   return dims.ParseFromString(attr.getValue().str());
928 }
929 
930 GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr,
931                                                 Builder *builder) {
932   ::xla::GatherDimensionNumbers dims;
933   if (!dims.ParseFromString(attr.getValue().str())) return {};
934   return ::xla::ConvertGatherDimensionNumbers(dims, builder);
935 }
936 
937 //===----------------------------------------------------------------------===//
938 // XlaDot op utilities.
939 //===----------------------------------------------------------------------===//
940 
941 bool HasValidDotDims(StringAttr attr) {
942   ::xla::DotDimensionNumbers dims;
943   return dims.ParseFromString(attr.getValue().str());
944 }
945 
946 DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, Builder *builder) {
947   ::xla::DotDimensionNumbers dims;
948   if (!dims.ParseFromString(attr.getValue().str())) return {};
949   return ::xla::ConvertDotDimensionNumbers(dims, builder);
950 }
951 
952 bool HasValidPrecisionConfig(StringAttr attr) {
953   ::xla::PrecisionConfig precision;
954   return precision.ParseFromString(attr.getValue().str());
955 }
956 
957 mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) {
958   ::xla::PrecisionConfig precision;
959   if (!precision.ParseFromString(attr.getValue().str())) return {};
960   return ::xla::ConvertPrecisionConfig(&precision, builder);
961 }
962 
963 //===----------------------------------------------------------------------===//
964 // XlaVariadicReduceV2 op utilities.
965 //===----------------------------------------------------------------------===//
966 
967 static void BuildBodyWithCall(PatternRewriter &rewriter, const Location &loc,
968                               mlir::SymbolRefAttr func,
969                               mlir::FunctionType func_ty, Region *body) {
970   OpBuilder::InsertionGuard guard(rewriter);
971 
972   Block *block = rewriter.createBlock(body);
973   auto inputs = func_ty.getInputs();
974   block->addArguments(inputs, SmallVector<Location>(inputs.size(), loc));
975   mlir::func::CallOp call_op = rewriter.create<mlir::func::CallOp>(
976       loc, func, func_ty.getResults(), block->getArguments());
977   rewriter.create<mhlo::ReturnOp>(loc, call_op.getResults());
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // Op converters.
982 //===----------------------------------------------------------------------===//
983 
984 NamedAttribute GetConvDimensionNumbersAttr(ArrayRef<int64_t> spatial_dims,
985                                            tensorflow::TensorFormat format,
986                                            Builder *builder) {
987   int64_t num_spatial_dims = spatial_dims.size();
988   int64_t num_dims = num_spatial_dims + 2;
989 
990   int64_t batch_dim = GetTensorBatchDimIndex(num_dims, format);
991   int64_t feature_dim = GetTensorFeatureDimIndex(num_dims, format);
992 
993 // Filter data_format is always HWIO, so the input channels dimension comes
994 // after all spatial dimensions.
995   int64_t kernel_input_feature_dim = num_spatial_dims;
996   int64_t kernel_output_feature_dim = num_spatial_dims + 1;
997   SmallVector<int64_t, 4> kernel_spatial_dimensions;
998   kernel_spatial_dimensions.resize(num_spatial_dims);
999   std::iota(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(),
1000             0);
1001 
1002   return builder->getNamedAttr(
1003       "dimension_numbers",
1004       ConvDimensionNumbersAttr::get(
1005           builder->getContext(), batch_dim, feature_dim, spatial_dims,
1006           kernel_input_feature_dim, kernel_output_feature_dim,
1007           kernel_spatial_dimensions, batch_dim, feature_dim, spatial_dims));
1008 }
1009 
1010 // Converts a TF::BiasAddOp to HLO.
1011 // This differs from a normal TF::AddOp with respect to how the data_format
1012 // is handled, which can optionally require a general broadcast of the
1013 // 'bias' term in a way that is not compatible with the standard left-padded
1014 // broadcast semantics (i.e. NCHW will broadcast into dimension 1).
1015 // The correct 'bias' broadcast will be synthesized manually.
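// For example, with data_format NCHW, a bias of shape [C] is broadcast along
// dimension 1 of an [N, C, H, W] value before the add.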
1016 class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
1017  public:
1018   using OpRewritePattern::OpRewritePattern;
1019   LogicalResult matchAndRewrite(TF::BiasAddOp op,
1020                                 PatternRewriter &rewriter) const override {
1021     Location loc = op.getLoc();
1022     tensorflow::TensorFormat data_format;
1023     if (!FormatFromString(op.data_format().str(), &data_format))
1024       return op.emitOpError("invalid data format");
1025 
1026     auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
1027     if (!value_type) return failure();
1028     auto feature_dim = GetFeatureDimension(data_format, value_type);
1029     auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
1030                                                   feature_dim, rewriter);
1031     Value add = rewriter.create<AddOp>(loc, op.value(), bias_broadcast);
1032     if (add.getType() != op.getType()) {
1033       add = rewriter.create<tensor::CastOp>(loc, op.getType(), add);
1034     }
1035     rewriter.replaceOp(op, {add});
1036     return success();
1037   }
1038 };
1039 
1040 // Converts tf.Conv2D to mhlo.dynamic_conv.
1041 // TODO(disc): Recover the static special case's performance by adding folding
1042 // and canonicalization and by removing ConvertConvOp.
1043 template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
1044 class ConvertConvDynamic : public OpRewritePattern<OpT> {
1045  public:
1046   using OpRewritePattern<OpT>::OpRewritePattern;
1047 
1048   bool GetPaddingValues(OpT &op, PatternRewriter &rewriter, Value input_size,
1049                         Value filter_size, int64_t dilation_rate,
1050                         int64_t stride, tensorflow::Padding padding_type,
1051                         Type shape_scalar_type, Value *padding_low,
1052                         Value *padding_high) const {
1053     // Stride must be > 0
1054     if (stride <= 0) return false;
1055     // Dilation rate must be >= 1
1056     if (dilation_rate < 1) return false;
1057 
1058     Location loc = op.getLoc();
1059     switch (padding_type) {
1060       case tensorflow::Padding::VALID: {
1061         auto zero =
1062             rewriter.create<arith::ConstantIntOp>(loc, 0, shape_scalar_type);
1063         *padding_low = *padding_high = zero;
1064         break;
1065       }
1066       case tensorflow::Padding::EXPLICIT:
1067         break;
1068       case tensorflow::Padding::SAME: {
1069         auto zero =
1070             rewriter.create<arith::ConstantIntOp>(loc, 0, shape_scalar_type);
1071         auto one =
1072             rewriter.create<arith::ConstantIntOp>(loc, 1, shape_scalar_type);
1073         auto two =
1074             rewriter.create<arith::ConstantIntOp>(loc, 2, shape_scalar_type);
1075         // See also the parallel implementation in
1076         // GetWindowedOutputSizeFromDimsV2:
1077         //   effective_filter_size = (filter_size - 1) * dilation_rate + 1
1078         Value stride_value = rewriter.create<arith::ConstantIntOp>(
1079             loc, stride, shape_scalar_type);
1080         Value dilation_rate_value = rewriter.create<arith::ConstantIntOp>(
1081             loc, dilation_rate, shape_scalar_type);
1082         Value effective_filter_size_op = rewriter.create<arith::AddIOp>(
1083             loc, one,
1084             rewriter.create<arith::MulIOp>(
1085                 loc, dilation_rate_value,
1086                 rewriter.create<arith::SubIOp>(loc, filter_size, one)));
1087         // output_size = (input_size + stride - 1) / stride;
1088         Value output_size = rewriter.create<arith::DivUIOp>(
1089             loc,
1090             rewriter.create<arith::AddIOp>(
1091                 loc, input_size,
1092                 rewriter.create<arith::SubIOp>(loc, stride_value, one)),
1093             stride_value);
1094         // std::max(int64{0}, (output_size - 1) * stride +
1095         //     effective_filter_size - input_size);
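        // As an illustrative check with made-up values (input_size = 10,
        // filter_size = 3, stride = 2, dilation_rate = 1):
        //   effective_filter_size = (3 - 1) * 1 + 1 = 3
        //   output_size = (10 + 2 - 1) / 2 = 5
        //   padding_needed = max(0, (5 - 1) * 2 + 3 - 10) = 1
        //   padding_low = 1 / 2 = 0, padding_high = 1 - 0 = 1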
1096         Value padding_needed = rewriter.create<arith::SubIOp>(
1097             loc,
1098             rewriter.create<arith::AddIOp>(
1099                 loc, effective_filter_size_op,
1100                 rewriter.create<arith::MulIOp>(
1101                     loc, stride_value,
1102                     rewriter.create<arith::SubIOp>(loc, output_size, one))),
1103             input_size);
1104         Value cond = rewriter.create<mlir::arith::CmpIOp>(
1105             loc, arith::CmpIPredicate::sge, padding_needed, zero);
1106         padding_needed = rewriter.create<mlir::arith::SelectOp>(
1107             loc, padding_needed.getType(), cond, padding_needed, zero);
1108         *padding_low =
1109             rewriter.create<arith::DivUIOp>(loc, padding_needed, two);
1110         *padding_high =
1111             rewriter.create<arith::SubIOp>(loc, padding_needed, *padding_low);
1112         break;
1113       }
1114     }
1115     return true;
1116   }
1117 
1118   LogicalResult matchAndRewriteDynamicConv(OpT op,
1119                                            PatternRewriter &rewriter) const {
1120     tensorflow::TensorFormat data_format;
1121     if (!FormatFromString(op.data_format().str(), &data_format))
1122       return op.emitOpError("invalid data format");
1123 
1124     tensorflow::Padding padding;
1125     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1126       return failure();
1127 
1128     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1129     auto filter_ty =
1130         op.filter().getType().template dyn_cast<RankedTensorType>();
1131     auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1132     if (!input_ty || !filter_ty || !result_ty) return failure();
1133     // TODO(disc): Remove this constraint once fold and canonicalization are
1134     // implemented.
1135     if (input_ty.hasStaticShape() && filter_ty.hasStaticShape())
1136       return failure();
1137 
1138     ArrayRef<Attribute> dilations = op.dilations().getValue();
1139     ArrayRef<Attribute> strides = op.strides().getValue();
1140     ArrayRef<Attribute> explicit_paddings;
1141     if (padding == tensorflow::Padding::EXPLICIT) {
1142       // The EXPLICIT padding mode and the associated attribute are attached
1143       // to Conv2D.
1144       explicit_paddings =
1145           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1146     }
1147 
1148     SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1149     SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1150     SmallVector<int64_t, num_spatial_dims> window_strides;
1151     SmallVector<Value, num_spatial_dims * 2> paddings;
1152 
1153     auto get_int = [](Attribute attr) {
1154       return attr.template cast<IntegerAttr>().getInt();
1155     };
1156 
1157     constexpr int num_dims = num_spatial_dims + 2;
1158 
1159     Location loc = op.getLoc();
1160     auto shape_scalar_type = rewriter.getIntegerType(32);
1161 
1162     auto get_const = [&](int64_t val) {
1163       return rewriter.create<mlir::arith::ConstantIntOp>(loc, val,
1164                                                          shape_scalar_type);
1165     };
1166     auto get_dim_value = [&](Value val, int64_t dim) {
1167       Value dim_value = rewriter.create<tensor::DimOp>(loc, val, dim);
1168       return rewriter.create<arith::IndexCastOp>(loc, shape_scalar_type,
1169                                                  dim_value);
1170     };
1171 
1172     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1173       const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1174       spatial_dim_indices.push_back(dim);
1175 
1176       const int64_t dilation = get_int(dilations[dim]);
1177       rhs_dilations.push_back(dilation);
1178       const int64_t stride = get_int(strides[dim]);
1179       window_strides.push_back(stride);
1180 
1181       Value pad_low, pad_high;
1182       if (padding == tensorflow::Padding::EXPLICIT) {
1183         pad_low = get_const(get_int(explicit_paddings[2 * dim]));
1184         pad_high = get_const(get_int(explicit_paddings[2 * dim + 1]));
1185       } else {
1186         auto input_size = get_dim_value(op.input(), dim);
1187         auto filter_size = get_dim_value(op.filter(), i);
1188         if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation,
1189                               stride, padding, shape_scalar_type, &pad_low,
1190                               &pad_high)) {
1191           return failure();
1192         }
1193       }
1194       paddings.push_back(pad_low);
1195       paddings.push_back(pad_high);
1196     }
1197     auto rhs_dilations_attr = rewriter.getNamedAttr(
1198         "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1199 
1200     auto window_strides_attr = rewriter.getNamedAttr(
1201         "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1202 
1203     auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1204         spatial_dim_indices, data_format, &rewriter);
1205 
1206     const int64_t input_channels =
1207         GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1208     // The filter's data_format is always HWIO, so the input channels dimension
1209     // comes after all spatial dimensions.
1210     const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1211     // TensorFlow convolution op verifies that the number of input channels is
1212     // divisible by the number of filter channels.
1213     // For depthwise convolution the feature_group_count argument would be set
1214     // to the input feature dimension.
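    // For example (illustrative numbers only): input_channels = 8 and
    // filter_channels = 2 give feature_group_count = 4 for a grouped
    // convolution, while a depthwise convolution uses feature_group_count = 8.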
1215     const int64_t feature_group_count =
1216         depthwise_conv ? input_channels : input_channels / filter_channels;
1217     auto feature_group_count_attr = rewriter.getNamedAttr(
1218         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1219 
1220     auto batch_group_count_attr = rewriter.getNamedAttr(
1221         "batch_group_count", rewriter.getI64IntegerAttr(1));
1222 
1223     Value paddings_op = rewriter.create<tensor::FromElementsOp>(
1224         op.getLoc(),
1225         RankedTensorType::get(2 * num_spatial_dims, rewriter.getI32Type()),
1226         paddings);
1227 
1228     SmallVector<Value, 3> operands(op.getOperands());
1229     operands.push_back(paddings_op);
1230     // Reshape the filter to {spatial_dims..., 1, in_channels *
1231     // channel_multiplier}.
1232     if (depthwise_conv) {
1233       ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1234       llvm::SmallVector<int64_t, num_dims> new_shape(
1235           filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1236       new_shape.push_back(1);
1237       new_shape.push_back(filter_shape[num_spatial_dims] *
1238                           filter_shape[num_spatial_dims + 1]);
1239       operands[1] = rewriter.create<mhlo::ReshapeOp>(
1240           op.getLoc(),
1241           RankedTensorType::get(new_shape, filter_ty.getElementType()),
1242           operands[1]);
1243     }
1244     NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
1245                               dimension_numbers_attr, feature_group_count_attr,
1246                               batch_group_count_attr};
1247     rewriter.replaceOpWithNewOp<mhlo::DynamicConvOp>(op, op.getType(), operands,
1248                                                      llvm::makeArrayRef(attrs));
1249     return success();
1250   }
1251 
1252   LogicalResult matchAndRewrite(OpT op,
1253                                 PatternRewriter &rewriter) const override {
1254     return matchAndRewriteDynamicConv(op, rewriter);
1255   }
1256 };
1257 
1258 using ConvertConv2DDynamic =
1259     ConvertConvDynamic<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1260 
1261 // Converts the TensorFlow conv op in template to the generic HLO conv op by
1262 // converting TensorFlow op attributes to HLO op attributes.
1263 //
1264 // Sample result for Conv2D:
1265 //
1266 //   %conv = "mhlo.convolution"(%input, %filter) {
1267 //     strides = [1, 2],
1268 //     paddings = [[1, 0], [1, 1]],
1269 //     ...
1270 //   }
1271 //
1272 // This pattern is not defined using declarative rewrite rules because computing
1273 // the paddings attribute requires multiple source op attributes and result op
1274 // attributes anyway. Defining it as a declarative rewrite rule would introduce
1275 // duplication in the C++ helper methods.
1276 template <typename OpTy, int num_spatial_dims, bool depthwise_conv = false>
1277 class ConvertConvOp : public OpRewritePattern<OpTy> {
1278  public:
1279   using OpRewritePattern<OpTy>::OpRewritePattern;
1280 
1281   LogicalResult matchAndRewrite(OpTy op,
1282                                 PatternRewriter &rewriter) const override {
1283     tensorflow::TensorFormat data_format;
1284     if (!FormatFromString(op.data_format().str(), &data_format))
1285       return op.emitOpError("invalid data format");
1286 
1287     tensorflow::Padding padding;
1288     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1289       return failure();
1290 
1291     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1292     auto filter_ty =
1293         op.filter().getType().template dyn_cast<RankedTensorType>();
1294 
1295     // With the exception of input's batch dimension, input and filter need to
1296     // have static shape for calculation of HLO paddings and feature group count
1297     // attributes. Filter is validated here, input is mostly validated at use.
1298     if (!input_ty || !filter_ty || !filter_ty.hasStaticShape())
1299       return failure();
1300 
1301     ArrayRef<Attribute> dilations = op.dilations().getValue();
1302     ArrayRef<Attribute> strides = op.strides().getValue();
1303     ArrayRef<Attribute> explicit_paddings;
1304     if (padding == tensorflow::Padding::EXPLICIT) {
1305       // EXPLICIT padding mode and the associated attribute is limited to
1306       // Conv2D. So, fetch attribute by identifier instead of the
1307       // op.explicit_paddings() attribute getter.
1308       explicit_paddings =
1309           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1310     }
1311 
1312     SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1313     SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1314     SmallVector<int64_t, num_spatial_dims> window_strides;
1315     SmallVector<int64_t, num_spatial_dims * 2> paddings;
1316 
1317     auto get_int = [](Attribute attr) {
1318       return attr.template cast<IntegerAttr>().getInt();
1319     };
1320 
1321     constexpr int num_dims = num_spatial_dims + 2;
1322     for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1323       const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1324       spatial_dim_indices.push_back(dim);
1325 
1326       const int64_t dilation = get_int(dilations[dim]);
1327       rhs_dilations.push_back(dilation);
1328       const int64_t stride = get_int(strides[dim]);
1329       window_strides.push_back(stride);
1330 
1331       int64_t pad_low, pad_high;
1332       if (padding == tensorflow::Padding::EXPLICIT) {
1333         pad_low = get_int(explicit_paddings[2 * dim]);
1334         pad_high = get_int(explicit_paddings[2 * dim + 1]);
1335       } else {
1336         int64_t output_size;
1337         int64_t pad_low_int64;
1338         int64_t pad_high_int64;
1339         int64_t input_size = input_ty.getDimSize(dim);
1340         if (input_size == ShapedType::kDynamicSize) return failure();
1341         tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1342             input_size, filter_ty.getDimSize(i), dilation, stride, padding,
1343             &output_size, &pad_low_int64, &pad_high_int64);
1344         if (!status.ok()) return failure();
1345         pad_low = pad_low_int64;
1346         pad_high = pad_high_int64;
1347       }
1348       paddings.push_back(pad_low);
1349       paddings.push_back(pad_high);
1350     }
1351 
1352     auto rhs_dilations_attr = rewriter.getNamedAttr(
1353         "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1354 
1355     auto window_strides_attr = rewriter.getNamedAttr(
1356         "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1357 
1358     auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1359         spatial_dim_indices, data_format, &rewriter);
1360 
1361     const int64_t input_channels =
1362         GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1363     if (input_channels == ShapedType::kDynamicSize) return failure();
1364     // The filter's data_format is always HWIO, so the input channels dimension
1365     // comes after all spatial dimensions.
1366     const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1367     // TensorFlow convolution op verifies that the number of input channels is
1368     // divisible by the number of filter channels.
1369     // For depthwise convolution the feature_group_count argument would be set
1370     // to the input feature dimension.
1371     const int64_t feature_group_count =
1372         depthwise_conv ? input_channels : input_channels / filter_channels;
1373     auto feature_group_count_attr = rewriter.getNamedAttr(
1374         "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1375 
1376     auto batch_group_count_attr = rewriter.getNamedAttr(
1377         "batch_group_count", rewriter.getI64IntegerAttr(1));
1378 
1379     RankedTensorType paddings_ty = RankedTensorType::get(
1380         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
1381     auto paddings_attr = rewriter.getNamedAttr(
1382         "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
1383 
1384     SmallVector<Value, 2> operands(op.getOperands());
1385     // Reshape the filter to {spatial_dims..., 1, in_channels *
1386     // channel_multiplier}.
1387     if (depthwise_conv) {
1388       ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1389       llvm::SmallVector<int64_t, num_dims> new_shape(
1390           filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1391       new_shape.push_back(1);
1392       new_shape.push_back(filter_shape[num_spatial_dims] *
1393                           filter_shape[num_spatial_dims + 1]);
1394       operands[1] = rewriter.create<mhlo::ReshapeOp>(
1395           op.getLoc(),
1396           RankedTensorType::get(new_shape, filter_ty.getElementType()),
1397           operands[1]);
1398     }
1399     NamedAttribute attrs[] = {rhs_dilations_attr,     window_strides_attr,
1400                               dimension_numbers_attr, feature_group_count_attr,
1401                               batch_group_count_attr, paddings_attr};
1402     rewriter.replaceOpWithNewOp<ConvolutionOp>(op, op.getType(), operands,
1403                                                llvm::makeArrayRef(attrs));
1404     return success();
1405   }
1406 };
1407 
1408 using ConvertConv2DOp = ConvertConvOp<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1409 using ConvertConv3DOp = ConvertConvOp<TF::Conv3DOp, /*num_spatial_dims=*/3>;
1410 using ConvertDepthConv2DOp =
1411     ConvertConvOp<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
1412                   /*depthwise_conv=*/true>;
1413 
1414 // Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const.
1415 class ConvertPadOpDynamic : public OpRewritePattern<TF::PadV2Op> {
1416  public:
1417   using OpRewritePattern::OpRewritePattern;
1418   // TODO(disc): Recover the static special case's performance with folding and
1419   // canonicalization.
1420   LogicalResult matchAndRewrite(TF::PadV2Op op,
1421                                 PatternRewriter &rewriter) const override {
1422     Location loc = op.getLoc();
1423     auto input = op.input();
1424     auto paddings = op.paddings();
1425     auto constant_values = op.constant_values();
1426     auto input_type = input.getType().dyn_cast<RankedTensorType>();
1427     auto paddings_type = paddings.getType().dyn_cast<RankedTensorType>();
1428     if (!input_type || !paddings_type || !paddings_type.hasStaticShape())
1429       return failure();
1430 
1431     // TODO(disc): Remove this constraint once fold and canonicalization are
1432     // implemented.
1433     if (input_type.hasStaticShape()) return failure();
1434 
1435     int input_rank = input_type.getRank();
1436     // interior padding
1437     std::vector<int64_t> interior_values(input_rank, 0);
1438     auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter);
1439 
1440     Value interior_padding_tensor =
1441         rewriter.create<mhlo::ConstantOp>(loc, interior_attr);
1442     Type paddings_elem_ty = paddings_type.getElementType();
1443     if (!paddings_elem_ty.isInteger(64)) {
1444       interior_padding_tensor = rewriter.create<mhlo::ConvertOp>(
1445           loc, interior_padding_tensor, paddings_elem_ty);
1446     }
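    // To illustrate the transpose + reshape + slice below with made-up values:
    // a rank-2 input with paddings = [[1, 2], [3, 4]] ([low, high] per dim) is
    // transposed to [[1, 3], [2, 4]], reshaped to [1, 3, 2, 4], and then
    // sliced into lows [1, 3] and highs [2, 4].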
1447     llvm::SmallVector<int64_t, 2> transposed_shape = {2, input_rank};
1448     auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter);
1449     Value transposed_paddings =
1450         rewriter.create<mhlo::TransposeOp>(loc, paddings, transpose_attr);
1451     Value reshaped_paddings = rewriter.create<mhlo::ReshapeOp>(
1452         loc, RankedTensorType::get({input_rank * 2}, paddings_elem_ty),
1453         transposed_paddings);
1454 
1455     auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter);
1456     auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1457     auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1458     Value left_padding_tensor = rewriter.create<mhlo::SliceOp>(
1459         loc, reshaped_paddings, left_padding_start_attr,
1460         left_padding_limit_attr, left_padding_stride_attr);
1461 
1462     auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1463     auto right_padding_limit_attr =
1464         GetI64ElementsAttr({2 * input_rank}, &rewriter);
1465     auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1466     Value right_padding_tensor = rewriter.create<mhlo::SliceOp>(
1467         loc, reshaped_paddings, right_padding_start_attr,
1468         right_padding_limit_attr, right_padding_stride_attr);
1469 
1470     rewriter.replaceOpWithNewOp<mhlo::DynamicPadOp>(
1471         op, op.getType(), input, constant_values, left_padding_tensor,
1472         right_padding_tensor, interior_padding_tensor);
1473 
1474     return success();
1475   }
1476 };
1477 
1478 class ConvertGatherNdOpDynamic : public OpRewritePattern<TF::GatherNdOp> {
1479   using OpRewritePattern<TF::GatherNdOp>::OpRewritePattern;
1480   // Converts tf.GatherNdOp to mhlo.DynamicGatherOp.
1481   // Here we leave 'slice_sizes' as an Attr, without defining a new
1482   // DynamicGatherOp, since GatherDimensionNumbers already provides enough
1483   // information for shape inference and code generation of mhlo::GatherOp. '?'
1484   // will be filled into slice_sizes for dimensions that are dynamically sized.
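  // As an illustrative example (hypothetical shapes): params of type
  // tensor<?x?x10xf32> and indices of type tensor<5x2xi32> give
  // num_index_dims = 2, slice_sizes = [1, 1, 10], collapsed_slice_dims =
  // [0, 1], offset_dims = [1], start_index_map = [0, 1], and
  // index_vector_dim = 1.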
1485   // TODO(disc): Recover the static special case's performance with folding and
1486   // canonicalization.
1487   LogicalResult matchAndRewrite(TF::GatherNdOp op,
1488                                 PatternRewriter &rewriter) const override {
1489     Location loc = op.getLoc();
1490     auto params = op.params();
1491     auto params_ty = params.getType().dyn_cast<RankedTensorType>();
1492     auto indices = op.indices();
1493     auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
1494     if (!params_ty || !indices_ty) return failure();
1495     auto params_rank = params_ty.getRank();
1496     auto indices_rank = indices_ty.getRank();
1497     int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1);
1498     // The last dim of indices of GatherNdOp must be statically shaped.
1499     if (num_index_dims == ShapedType::kDynamicSize) return failure();
1500 
1501     SmallVector<int64_t, 4> slice_sizes;
1502     slice_sizes.reserve(params_rank);
1503     for (int64_t i = 0; i < params_rank; ++i) {
1504       if (i < num_index_dims) {
1505         slice_sizes.push_back(1);
1506       } else {
1507         // potentially dynamic
1508         int64_t dim_size = params_ty.getDimSize(i);
1509         slice_sizes.push_back(dim_size);
1510       }
1511     }
1512     SmallVector<Value, 4> slice_sizes_vals;
1513     Value slice_sizes_value = nullptr;
1514     for (int64_t i = 0; i < params_rank; ++i) {
1515       if (i < num_index_dims) {
1516         slice_sizes_vals.push_back(rewriter.create<arith::ConstantOp>(
1517             loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1)));
1518       } else {
1519         int64_t dim_size = params_ty.getDimSize(i);
1520         if (dim_size != ShapedType::kDynamicSize) {
1521           slice_sizes_vals.push_back(rewriter.create<arith::ConstantOp>(
1522               loc,
1523               rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size)));
1524         } else {
1525           slice_sizes_vals.push_back(rewriter.create<arith::IndexCastOp>(
1526               loc, indices_ty.getElementType(),
1527               rewriter.create<tensor::DimOp>(loc, params, i)));
1528         }
1529       }
1530     }
1531     slice_sizes_value =
1532         rewriter.create<tensor::FromElementsOp>(loc, slice_sizes_vals);
1533 
1534     // collapsed_slice_dims
1535     SmallVector<int64_t, 4> collapsed_slice_dims;
1536     collapsed_slice_dims.reserve(num_index_dims);
1537     for (int64_t i = 0; i < num_index_dims; ++i) {
1538       collapsed_slice_dims.push_back(i);
1539     }
1540     // offset_dims
1541     SmallVector<int64_t, 4> offset_dims;
1542     offset_dims.reserve(params_rank - num_index_dims);
1543     for (int64_t i = num_index_dims; i < params_rank; i++) {
1544       offset_dims.push_back(i + indices_rank - 1 - num_index_dims);
1545     }
1546     // start_index_map
1547     SmallVector<int64_t, 4> start_index_map;
1548     start_index_map.reserve(num_index_dims);
1549     for (int64_t i = 0; i < num_index_dims; i++) {
1550       start_index_map.push_back(i);
1551     }
1552     // index_vector_dim
1553     int64_t index_vector_dim = indices_rank - 1;
1554 
1555     auto dims_attr = GatherDimensionNumbersAttr::get(
1556         rewriter.getContext(), offset_dims, collapsed_slice_dims,
1557         start_index_map, index_vector_dim);
1558     // TODO(disc): Remove this if-statement once fold and canonicalization are
1559     // implemented.
1560     if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) {
1561       rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
1562           op, op.getType(), op.params(), op.indices(), dims_attr,
1563           GetI64ElementsAttr(slice_sizes, &rewriter));
1564     } else {
1565       rewriter.replaceOpWithNewOp<mhlo::DynamicGatherOp>(
1566           op, op.getType(), op.params(), op.indices(), slice_sizes_value,
1567           dims_attr);
1568     }
1569     return success();
1570   }
1571 };
1572 
1573 // Converts BF16 FloorDiv op to have casting operators on either end as BF16
1574 // division can result in strange behavior.
1575 //
1576 //      floordiv = cast(floordiv(cast(left), cast(right)))
1577 //
1578 //   %left_cast = cast(%left)
1579 //   %right_cast = cast(%right)
1580 //   %div = div(%left_cast, %right_cast)
1581 //   %floored = floor(%div)
1582 //   %floored_cast = cast(%floored)
1583 //
1584 // Required to manually specify the intermediate types.
1585 class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
1586  public:
1587   using OpRewritePattern::OpRewritePattern;
1588 
1589   LogicalResult matchAndRewrite(TF::FloorDivOp op,
1590                                 PatternRewriter &rewriter) const override {
1591     auto l = op.x();
1592     auto r = op.y();
1593     auto element_type = getElementTypeOrSelf(l.getType());
1594     if (!element_type.isBF16()) return failure();
1595 
1596     auto out_type = op.z().getType().cast<TensorType>();
1597 
1598     l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
1599     r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
1600 
1601     auto intermediate = rewriter.create<TF::FloorDivOp>(
1602         op.getLoc(),
1603         ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
1604         r);
1605 
1606     auto floor_op =
1607         rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
1608     rewriter.replaceOp(op, floor_op.getResult());
1609     return success();
1610   }
1611 };
1612 
1613 class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
1614  public:
1615   using OpRewritePattern::OpRewritePattern;
1616 
1617   LogicalResult matchAndRewrite(TF::BroadcastToOp op,
1618                                 PatternRewriter &rewriter) const override {
1619     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1620     auto output_type = op.output().getType();
1621     if (!input_type) {
1622       return rewriter.notifyMatchFailure(op, "requires ranked input shape");
1623     }
1624     llvm::SmallVector<int64_t, 4> broadcast_dimensions;
1625     if (input_type.getRank() > 0) {
1626       auto ranked_output_type = output_type.dyn_cast<RankedTensorType>();
1627       if (!ranked_output_type) {
1628         return rewriter.notifyMatchFailure(op, "requires ranked output shape");
1629       }
1630       auto rank_diff = ranked_output_type.getRank() - input_type.getRank();
1631       // The tf.BroadcastTo op performs "right-aligned" numpy-style
1632       // broadcasting.
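      // For instance (shapes made up for illustration): a rank-2 input
      // broadcast to a rank-4 output gives rank_diff = 2 and
      // broadcast_dimensions = [2, 3], i.e. the input dims map to the
      // trailing output dims.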
1633       broadcast_dimensions = llvm::to_vector<4>(
1634           llvm::seq<int64_t>(rank_diff, ranked_output_type.getRank()));
1635     }
1636     rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
1637         op, output_type, op.input(), op.shape(),
1638         rewriter.getI64TensorAttr(broadcast_dimensions));
1639     return success();
1640   }
1641 };
1642 
1643 /// Converts a TF::RollOp to HLO. Only supports the 0D axis and shift case, and
1644 /// the axis has to be a constant.
1645 class ConvertRollOp : public OpRewritePattern<TF::RollOp> {
1646  public:
1647   using OpRewritePattern::OpRewritePattern;
1648   LogicalResult matchAndRewrite(TF::RollOp op,
1649                                 PatternRewriter &rewriter) const override {
1650     auto shift_ty = op.shift().getType().dyn_cast<RankedTensorType>();
1651     if (!shift_ty || shift_ty.getRank() != 0) {
1652       return rewriter.notifyMatchFailure(
1653           op, "require the type of shift to be 0D tensor");
1654     }
1655 
1656     APInt val;
1657     if (!matchPattern(op.axis(), m_ConstantInt(&val))) {
1658       return rewriter.notifyMatchFailure(op, "require axis to be constant");
1659     }
1660     int axis = val.getSExtValue();
1661 
1662     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1663     if (!input_ty || !input_ty.hasStaticShape()) {
1664       return rewriter.notifyMatchFailure(
1665           op, "require the type of input to have static shapes");
1666     }
1667     ArrayRef<int64_t> input_shape = input_ty.getShape();
1668     int input_rank = input_ty.getRank();
1669     if (axis < 0) axis += input_rank;
1670 
1671     // Adjust large offsets into [0, axis_size). This also makes negative
1672     // offsets positive.
1673     // offset = ((offset % axis_size) + axis_size) % axis_size
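    // E.g. (illustrative values) with axis_size = 5 and shift = -7:
    // ((-7 % 5) + 5) % 5 = (-2 + 5) % 5 = 3, so rolling by -7 is equivalent to
    // rolling by 3.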
1674     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1675     Value offset = op.shift();
1676     auto axis_size = b.create<mhlo::ConstantOp>(b.getIntegerAttr(
1677         getElementTypeOrSelf(offset.getType()), input_shape[axis]));
1678     offset = b.create<RemOp>(
1679         b.create<AddOp>(b.create<RemOp>(offset, axis_size), axis_size),
1680         axis_size);
1681 
1682     // Stack two copies of the dimension, then slice from the calculated
1683     // offset. This also works if shift is not constant.
1684     // DynamicSliceOp requires the sizes to be integers, and we can get that
1685     // information from the input shape.
1686     auto concat = b.create<ConcatenateOp>(ValueRange{op.input(), op.input()},
1687                                           b.getI64IntegerAttr(axis));
1688     Value zero = b.create<mhlo::ConstantOp>(
1689         b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0));
1690     SmallVector<Value> slice_begin_indices(input_rank, zero);
1691     slice_begin_indices[axis] = b.create<SubtractOp>(axis_size, offset);
1692     rewriter.replaceOpWithNewOp<DynamicSliceOp>(
1693         op, input_ty, concat, slice_begin_indices,
1694         rewriter.getI64TensorAttr(input_shape));
1695     return success();
1696   }
1697 };
1698 
1699 /// Converts a TF::LeakyReluOp to HLO.
1700 /// LeakyRelu(x) = alpha * x if x < 0 else x.
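/// For example (values only illustrative): with alpha = 0.2,
/// LeakyRelu([-1.0, 3.0]) = [-0.2, 3.0].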
1701 class ConvertLeakyReluOp : public OpRewritePattern<TF::LeakyReluOp> {
1702  public:
1703   using OpRewritePattern::OpRewritePattern;
1704   LogicalResult matchAndRewrite(TF::LeakyReluOp op,
1705                                 PatternRewriter &rewriter) const override {
1706     Location loc = op.getLoc();
1707     Value features = op.features();
1708 
1709     // Use ConstantLike for `alpha` to match the shape of feature.
1710     auto alphaVal = chlo::getConstantLike(
1711         rewriter, loc, op.alpha().convertToFloat(), features);
1712     Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
1713 
1714     Value leakyActivationVal =
1715         rewriter.create<mhlo::MulOp>(loc, features, alphaVal);
1716 
1717     Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1718         loc, features, zeroVal, ComparisonDirection::GT);
1719 
1720     rewriter.replaceOpWithNewOp<SelectOp>(op, compareGtZero, features,
1721                                           leakyActivationVal);
1722     return success();
1723   }
1724 };
1725 
1726 /// Converts a TF::LeakyReluGradOp to HLO.
1727 /// LeakyReluGrad(gradient, inputs) = gradient if input > 0
1728 /// else alpha * gradient.
1729 class ConvertLeakyReluGradOp : public OpRewritePattern<TF::LeakyReluGradOp> {
1730  public:
1731   using OpRewritePattern::OpRewritePattern;
1732   LogicalResult matchAndRewrite(TF::LeakyReluGradOp op,
1733                                 PatternRewriter &rewriter) const override {
1734     Location loc = op.getLoc();
1735     Value gradients = op.gradients();
1736     Value features = op.features();
1737     auto featureType = features.getType();
1738 
1739     // Use ConstantLike for `alpha` to match the shape of feature.
1740     auto alphaVal = chlo::getConstantLike(
1741         rewriter, loc, op.alpha().convertToFloat(), features);
1742     Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
1743 
1744     Value leakyGradientVal =
1745         rewriter.create<mhlo::MulOp>(loc, gradients, alphaVal);
1746 
1747     Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1748         loc, features, zeroVal, ComparisonDirection::GT);
1749 
1750     rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
1751                                           gradients, leakyGradientVal);
1752     return success();
1753   }
1754 };
1755 
1756 // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
1757 // For a Rank-2 input, it creates the following ops:
1758 //   %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
1759 //   %2 = "mhlo.iota"() {iota_dimension = 1 : i64}
1760 //   %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"}
1761 //   %4 = mhlo.constant dense<0.000000e+00> : tensor<f32>
1762 //   %5 = "mhlo.broadcast"(%4)
1763 //   %6 = "mhlo.select"(%3, %input, %5)
1764 //   %7 = "mhlo.reduce"(%6, %4) ({
1765 //   ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
1766 //     %9 = mhlo.add %arg1, %arg2 : tensor<f32>
1767 //     "mhlo.return"(%9) : (tensor<f32>) -> ()
1768 //   }) {dimensions = dense<0> : tensor<1xi64>}
1769 //
1770 // If the input's rank N is greater than 2, we will reshape it to R2 first and
1771 // create the above ops, then reshape it back to rank N/2.
1772 class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
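// For example (a hypothetical shape): a tensor<2x3x2x3xf32> input is reshaped
// to tensor<6x6xf32>, the masked reduction above extracts its diagonal, and the
// result is reshaped back to tensor<2x3xf32>.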
1773  public:
1774   using OpRewritePattern::OpRewritePattern;
1775 
1776   LogicalResult matchAndRewrite(TF::DiagPartOp op,
1777                                 PatternRewriter &rewriter) const override {
1778     auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1779     if (!input_type || !input_type.hasStaticShape()) return failure();
1780     int64_t num_dims = input_type.getRank();
1781     if (num_dims < 2 || num_dims % 2 != 0) return failure();
1782     const int64_t out_dims = num_dims / 2;
1783 
1784     int64_t new_size = 1;
1785     llvm::SmallVector<int64_t, 4> new_dims;
1786     for (int i = 0; i < out_dims; i++) {
1787       if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims))
1788         return op.emitOpError("invalid dimensions size");
1789       new_size *= input_type.getDimSize(i);
1790       new_dims.push_back(input_type.getDimSize(i));
1791     }
1792     Value reshaped_input = rewriter.create<mhlo::ReshapeOp>(
1793         op.getLoc(),
1794         RankedTensorType::get({new_size, new_size},
1795                               input_type.getElementType()),
1796         op.input());
1797     auto iota_type = RankedTensorType::get({new_size, new_size},
1798                                            rewriter.getIntegerType(32));
1799     auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1800                                          rewriter.getI64IntegerAttr(0));
1801     auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1802                                          rewriter.getI64IntegerAttr(1));
1803     Value compare = rewriter.create<CompareOp>(op.getLoc(), iota0, iota1,
1804                                                ComparisonDirection::EQ);
1805     Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
1806                                       0, &rewriter);
1807     Value zero_matrix = rewriter.create<BroadcastOp>(
1808         op.getLoc(), reshaped_input.getType(), zero,
1809         GetI64ElementsAttr({new_size, new_size}, &rewriter));
1810     Value masked =
1811         rewriter.create<SelectOp>(op.getLoc(), reshaped_input.getType(),
1812                                   compare, reshaped_input, zero_matrix);
1813     auto reduce = rewriter.create<ReduceOp>(op.getLoc(), masked, zero,
1814                                             GetI64ElementsAttr({0}, &rewriter));
1815     assert(!input_type.getElementType().isInteger(1) &&
1816            "data type should not be i1");
1817     BuildReduceBody<AddOp>(input_type.getElementType(), &reduce.body(),
1818                            &rewriter);
1819     rewriter.replaceOpWithNewOp<ReshapeOp>(
1820         op, RankedTensorType::get(new_dims, input_type.getElementType()),
1821         reduce.getResult(0));
1822     return success();
1823   }
1824 };
1825 
1826 // Converts TensorFlow MatrixDiagPartOp to HLO ops.
1827 class ConvertMatrixDiagPartV3Op
1828     : public OpRewritePattern<TF::MatrixDiagPartV3Op> {
1829   using Shape = llvm::SmallVector<int64_t, 4>;
1830 
1831   // Parse the "k" parameter. MatrixDiagPartV3 allows specifying the diagonal(s)
1832   // with k. This can be either a single value (for a single diagonal) or a
1833   // tuple of two values (starting and ending diagonal, for a band).
1834   LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const {
1835     DenseIntElementsAttr kattr;
1836     if (!matchPattern(op.k(), m_Constant(&kattr))) {
1837       return failure();
1838     }
1839     DenseIntElementsAttr::iterator it = kattr.begin();
1840     (*k)[0] = (*it).getSExtValue();
1841     it++;
1842     if (it == kattr.end()) {
1843       // Handle input like e.g. "k = 5", in which case we extract a single
1844       // diagonal.
1845       (*k)[1] = (*k)[0];
1846     } else {
1847       // Handle input like e.g. "k = [-1, 1]", in which case we extract a
1848       // band (multiple diagonals).
1849       (*k)[1] = (*it).getSExtValue();
1850     }
1851     return success();
1852   }
1853 
1854   // Utility method for broadcasting integer constants to a given shape.
1855   BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant,
1856                                 int int_size, PatternRewriter &rewriter) const {
1857     return rewriter.create<BroadcastOp>(
1858         loc, RankedTensorType::get(shape, rewriter.getIntegerType(int_size)),
1859         GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant,
1860                              &rewriter),
1861         GetI64ElementsAttr(shape, &rewriter));
1862   }
1863 
1864  public:
1865   using OpRewritePattern::OpRewritePattern;
1866 
1867   LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op,
1868                                 PatternRewriter &rewriter) const override {
1869     Location loc = op.getLoc();
1870     ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
1871 
1872     // Align is a string specifying how superdiagonals and subdiagonals should
1873     // be aligned/padded for diagonals that are shorter than max_diag_len. The
1874     // format is "{super}_{sub}", with {super} the superdiagonal alignment and
1875     // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the
1876     // left, "RIGHT" means rows will be padded to the right. The default is
1877     // "RIGHT_LEFT".
1878     StringRef align = op->getAttrOfType<StringAttr>("align").getValue();
1879     enum Alignment { kLeft, kRight };
1880 
1881     // default is RIGHT_LEFT
1882     Alignment superdiagonal_align = kRight;
1883     Alignment subdiagonal_align = kLeft;
1884 
1885     if (align == "RIGHT_LEFT") {
1886       superdiagonal_align = kRight;
1887       subdiagonal_align = kLeft;
1888     } else if (align == "RIGHT_RIGHT") {
1889       superdiagonal_align = kRight;
1890       subdiagonal_align = kRight;
1891     } else if (align == "LEFT_RIGHT") {
1892       superdiagonal_align = kLeft;
1893       subdiagonal_align = kRight;
1894     } else if (align == "LEFT_LEFT") {
1895       superdiagonal_align = kLeft;
1896       subdiagonal_align = kLeft;
1897     } else {
1898       return failure();  // unsupported alignment
1899     }
1900 
1901     // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and
1902     // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L].
1903     if (!input_type || !input_type.hasStaticShape()) return failure();
1904     int64_t num_dims = input_type.getRank();
1905     if (num_dims < 2) return failure();
1906     int64_t rows = input_type.getDimSize(num_dims - 2);  // rows
1907     int64_t cols = input_type.getDimSize(num_dims - 1);  // cols
1908 
1909     // We extract the diagonals from k[0] up to and including k[1].
1910     // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract
1911     // the main diagonal). It's negative for subdiagonals (under and to the left
1912     // of the main diagonal) and positive for superdiagonals (above and to the
1913     // right of the main diagonal).
1914     int64_t k[2];
1915     if (failed(ExtractK(op, &k))) return failure();
1916     int num_diags = k[1] - k[0] + 1;
1917 
1918     // Shifting diagonals away from the main diagonal might shorten them. This
1919     // is the longest diagonal we will see. We make this the last dimension of
1920     // the output shape.
1921     int64_t max_diag_len =
1922         std::min(rows + std::min(k[1], static_cast<int64_t>(0)),
1923                  cols + std::min(-k[0], static_cast<int64_t>(0)));
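    // Sanity-check example (numbers invented here): for a 3x4 matrix with
    // k = [-1, 1], max_diag_len = min(3 + min(1, 0), 4 + min(1, 0)) = 3.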
1924 
1925     // The first dimension is the index vector dimension we'll use for gather.
1926     // It's 1 here, but will be 2 once we glue x and y together.
1927     Shape indices_shape({1, num_diags, max_diag_len});
1928 
1929     RankedTensorType iota_type =
1930         RankedTensorType::get(indices_shape, rewriter.getIntegerType(32));
1931     Value iotaM =
1932         rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(1));
1933     Value iotaN =
1934         rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(2));
1935 
1936     // Broadcasted constants, of the same shape as iotaM and iotaN.
1937     Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter);
1938     Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter);
1939     Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter);
1940     Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter);
1941     Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter);
1942     Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter);
1943     Value b_max_diag_len =
1944         BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter);
1945 
1946     // d = k[1] - m
1947     // (A.k.a. the number of the diagonal, depending on m. Note that we
1948     //  subtract m here. This means we start with the superdiagonals and
1949     //  move downwards towards the subdiagonals. So the start indices will
1950     //  be decreasing.)
1951     Value d = rewriter.create<SubtractOp>(loc, b_k1, iotaM);
1952     Value neg_d = rewriter.create<NegOp>(loc, d);
1953 
1954     // diag_len_d = min(rows + min(d, 0), cols - max(d, 0))
1955     // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.)
1956     Value diag_len_d = rewriter.create<MinOp>(
1957         loc,
1958         rewriter.create<AddOp>(loc, b_rows,
1959                                rewriter.create<MinOp>(loc, d, b_zero)),
1960         rewriter.create<SubtractOp>(loc, b_cols,
1961                                     rewriter.create<MaxOp>(loc, d, b_zero)));
1962 
1963     // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise.
1964     Value cmp;
1965     if (subdiagonal_align == kRight && superdiagonal_align == kRight) {
1966       cmp = b_true;
1967     } else if (superdiagonal_align == kRight) {
1968       // offset = d>=0 ? max_diag_len - diag_len_d : 0
1969       cmp = rewriter.create<TF::GreaterEqualOp>(loc, d, b_zero);
1970     } else if (subdiagonal_align == kRight) {
1971       // offset = d<=0 ? max_diag_len - diag_len_d : 0
1972       cmp = rewriter.create<TF::LessEqualOp>(loc, d, b_zero);
1973     } else {
1974       // offset = 0
1975       cmp = b_false;
1976     }
1977 
1978     // This offset shifts the diagonals to the "left" or "right", depending
1979     // on alignment.
1980     Value offset = rewriter.create<SelectOp>(
1981         loc, b_zero.getType(), cmp,
1982         rewriter.create<SubtractOp>(loc, b_max_diag_len, diag_len_d), b_zero);
1983 
1984     // x = max(d, 0) - offset
1985     // y = max(-d, 0) - offset
1986     Value x = rewriter.create<SubtractOp>(
1987         loc, rewriter.create<MaxOp>(loc, d, b_zero), offset);
1988     Value y = rewriter.create<SubtractOp>(
1989         loc, rewriter.create<MaxOp>(loc, neg_d, b_zero), offset);
1990 
1991     Value n_plus_x = rewriter.create<AddOp>(loc, iotaN, x);
1992     Value n_plus_y = rewriter.create<AddOp>(loc, iotaN, y);
1993 
1994     // GatherOp is happy about letting us index out of bounds values, but those
1995     // values will be undefined. So we mask them later. Set up the boolean
1996     // expression that tells us which entries, in the output shape, are out of
1997     // bounds and thus become the padding_value.
1998     Value x_in_bounds = rewriter.create<AndOp>(
1999         loc,
2000         rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_x,
2001                                             b_zero),
2002         rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_x, b_cols));
2003     Value y_in_bounds = rewriter.create<AndOp>(
2004         loc,
2005         rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_y,
2006                                             b_zero),
2007         rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_y, b_rows));
2008     Value in_bounds = rewriter.create<ReshapeOp>(
2009         loc,
2010         RankedTensorType::get(Shape({num_diags, max_diag_len}),
2011                               rewriter.getIntegerType(1)),
2012         rewriter.create<AndOp>(loc, x_in_bounds, y_in_bounds));
2013 
2014     // Now combine x and y into the index data structure needed for gather.
2015     Shape concat_shape({2, num_diags, max_diag_len});
2016     Value start_indices = rewriter.create<ConcatenateOp>(
2017         loc, RankedTensorType::get(concat_shape, rewriter.getIntegerType(32)),
2018         mlir::ValueRange({n_plus_y, n_plus_x}),
2019         mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0));
2020 
2021     // Shape of the final output. (Except for dimension folding in the
2022     // single diagonal case.)
2023     Shape output_shape;
2024     for (int i = 0; i < num_dims - 2; i++) {
2025       output_shape.push_back(input_type.getDimSize(i));
2026     }
2027     output_shape.push_back(num_diags);
2028     output_shape.push_back(max_diag_len);
2029 
2030     // A slice is the shape of what GatherOp copies per lookup. So the last
2031     // two dimensions (M, N in the matrix-diag-part docs) are where we go
2032     // through entry by entry.
2033     ArrayRef<int64_t> input_shape = input_type.getShape();
2034     Shape slice_sizes(input_shape.begin(), input_shape.end());
2035     int slice_dimensions = slice_sizes.size();
2036     slice_sizes[slice_dimensions - 2] = 1;
2037     slice_sizes[slice_dimensions - 1] = 1;
2038 
2039     // Dimensions of the input we won't see in the output (M and N).
2040     SmallVector<int64_t, 2> collapsed_dims(
2041         {slice_dimensions - 2, slice_dimensions - 1});
2042 
2043     // Which dimensions (in the input) the two offset "columns" map to.
2044     SmallVector<int64_t, 2> start_index_map({num_dims - 2, num_dims - 1});
2045 
2046     // Gather the diagonal entries.
2047     // TODO(kramm): For a single diagonal, this might be slower than the
2048     //              mask + sum approach. Special-case num_diags==1?
2049     auto dims_attr = GatherDimensionNumbersAttr::get(
2050         rewriter.getContext(),
2051         /*offset_dims=*/llvm::to_vector<4>(llvm::seq<int64_t>(0, num_dims - 2)),
2052         /*collapsed_slice_dims=*/collapsed_dims, start_index_map,
2053         /*index_vector_dim=*/0);
2054     Value gather = rewriter.create<mhlo::GatherOp>(
2055         loc, op.input(), start_indices, dims_attr,
2056         GetI64ElementsAttr(slice_sizes, &rewriter));
2057 
2058     // We now need to broadcast the "in_bounds" boolean expression, as well as
2059     // the padding value, to do the final select.
2060     Shape broadcast_bounds;
2061     for (int i = 0; i < output_shape.size() - 2; i++) {
2062       broadcast_bounds.push_back(output_shape[i]);
2063     }
2064     Value b_in_bounds = rewriter.create<BroadcastOp>(
2065         loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
2066         in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
2067     Value b_padding = rewriter.create<BroadcastOp>(
2068         loc, op.padding_value(), GetI64ElementsAttr(output_shape, &rewriter));
2069 
2070     // Replace all out-of-bounds values in the result with padding_value.
2071     Value result =
2072         rewriter.create<SelectOp>(loc, b_in_bounds, gather, b_padding);
2073 
2074     if (num_diags == 1) {
2075       // matrix_diag_part folds away the 1-sized band dimension if we only
2076       // extract a single diagonal.
2077       result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
2078     }
2079 
2080     rewriter.replaceOp(op, result);
2081     return success();
2082   }
2083 };
2084 
2085 // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
2086 // depending on arity of the op.
2087 class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
2088  public:
2089   using OpRewritePattern::OpRewritePattern;
2090 
2091   LogicalResult matchAndRewrite(TF::EinsumOp op,
2092                                 PatternRewriter &rewriter) const override {
2093     StringAttr equation = op->getAttrOfType<StringAttr>("equation");
2094     if (op.N() == 1) {
2095       rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
2096           op, op.getType(), *op.inputs().begin(), equation);
2097     } else if (op.N() == 2) {
2098       ValueRange inputs = op.inputs();
2099       rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
2100                                             inputs[1], equation);
2101     } else {
2102       // TensorFlow EinsumOp verifies that the number of operands are at most
2103       // two.
2104       return failure();
2105     }
2106     return success();
2107   }
2108 };
2109 
2110 // Bypasses IdentityN op.
2111 class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
2112  public:
2113   using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
2114   LogicalResult matchAndRewrite(TF::IdentityNOp op,
2115                                 PatternRewriter &rewriter) const override {
2116     rewriter.replaceOp(op, op.getOperands());
2117     return success();
2118   }
2119 };
2120 
2121 template <typename OpTy>
2122 class ConvertFFTOp : public OpRewritePattern<OpTy> {
2123  public:
2124   using OpRewritePattern<OpTy>::OpRewritePattern;
2125   LogicalResult matchAndRewrite(OpTy op,
2126                                 PatternRewriter &rewriter) const override {
2127     auto input_ty = op.input().getType().template cast<ShapedType>();
2128     if (!input_ty.hasRank()) {
2129       return failure();
2130     }
2131     auto input_shape = input_ty.getShape();
2132     DenseIntElementsAttr fft_length_attr;
2133     if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) {
2134       return failure();
2135     }
2136     int64_t fft_length;
2137     if (fft_length_attr.getNumElements() != 0) {
2138       fft_length = fft_length_attr.getValues<IntegerAttr>()[0].getInt();
2139     } else {
2140       return failure();
2141     }
2142 
2143     int64_t expected_dim = fft_length;
2144     std::string fft_string = "RFFT";
2145     if (typeid(OpTy) == typeid(TF::IRFFTOp)) {
2146       expected_dim = fft_length / 2 + 1;
2147       fft_string = "IRFFT";
2148     }
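    // For example (fft_length chosen arbitrarily): with fft_length = 8, RFFT
    // expects the innermost input dim to be 8, while IRFFT expects
    // 8 / 2 + 1 = 5 (the length of the one-sided spectrum).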
2149     Location loc = op.getLoc();
2150 
2151     // The inner-most dim cannot be dynamic.
2152     if (input_ty.isDynamicDim(input_shape.size() - 1)) {
2153       return failure();
2154     }
2155 
2156     auto expected_shape = llvm::to_vector<4>(input_shape.drop_back());
2157     expected_shape.push_back(expected_dim);
2158 
2159     // Zero pad or truncate the last axis
2160     Value reshaped = op.input();
2161     SmallVector<int64_t, 4> begin_indices(input_shape.size(), 0);
2162     SmallVector<int64_t, 4> strides(input_shape.size(), 1);
2163 
2164     // Last dim larger than expected_dim, slice the input
2165     if (input_shape.back() > expected_dim) {
2166       reshaped = rewriter.create<SliceOp>(
2167           op.getLoc(),
2168           RankedTensorType::get(expected_shape, input_ty.getElementType()),
2169           op.input(), GetI64ElementsAttr(begin_indices, &rewriter),
2170           GetI64ElementsAttr(expected_shape, &rewriter),
2171           GetI64ElementsAttr(strides, &rewriter));
2172 
2173       // Last dim smaller than expected_dim, zero-pad the input
2174     } else if (input_ty.getShape().back() < expected_dim) {
2175       SmallVector<int64_t, 4> no_padding(input_shape.size(), 0);
2176       SmallVector<int64_t, 4> padding(input_shape.size() - 1, 0);
2177       padding.push_back(expected_dim - input_shape.back());
2178       Value zero =
2179           GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter);
2180       reshaped = rewriter.create<PadOp>(
2181           loc, RankedTensorType::get(expected_shape, input_ty.getElementType()),
2182           op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter),
2183           GetI64ElementsAttr(padding, &rewriter),
2184           GetI64ElementsAttr(no_padding, &rewriter));
2185     }
2186 
2187     rewriter.replaceOpWithNewOp<FftOp>(
2188         op, op.getType(), reshaped,
2189         FftTypeAttr::get(rewriter.getContext(),
2190                          symbolizeFftType(fft_string).getValue()),
2191         rewriter.getI64TensorAttr(fft_length));
2192     return success();
2193   }
2194 };
2195 
2196 using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
2197 using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;
2198 
2199 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
2200 // BatchNormGradOp for training and a sequence of binary ops for inference.
2201 // TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
2202 template <typename FusedBatchNormGradOpT>
2203 class ConvertFusedBatchNormGradBase
2204     : public OpRewritePattern<FusedBatchNormGradOpT> {
2205  public:
2206   using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;
2207 
2208   LogicalResult matchAndRewrite(FusedBatchNormGradOpT op,
2209                                 PatternRewriter &rewriter) const override {
2210     Location loc = op.getLoc();
2211     Value grad = op.y_backprop();
2212     Value act = op.x();
2213     Value scale = op.scale();
2214     Value mean = op.reserve_space_1();
2215     Value var = op.reserve_space_2();
2216 
2217     // TODO(b/141785544): Update this to not require static shapes.
2218     // activation shape needs to be static to convert negative indices in
2219     // TensorFlow to absolute indices required by HLO.
2220     RankedTensorType act_type =
2221         act.getType().template dyn_cast<RankedTensorType>();
2222     if (!act_type) return failure();
2223     Type act_ele_type = act_type.getElementType();
2224     // To support mixed precision, the statistics type, which may be more
2225     // precise than the input types, is used for this op.
2226     Type kernel_type =
2227         scale.getType().template cast<TensorType>().getElementType();
2228     grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
2229     act = rewriter.create<ConvertOp>(loc, act, kernel_type);
2230 
2231     tensorflow::TensorFormat data_format;
2232     if (!FormatFromString(op.data_format().str(), &data_format))
2233       return op.emitOpError("invalid data format");
2234 
2235     auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act);
2236     auto feature_dim = feature_dim_attr.getValue().getSExtValue();
2237 
2238     // Gets the result values.
2239     Value x_backprop, scale_backprop, offset_backprop;
2240     if (op.is_training()) {  // training
2241       // TODO(b/145536565): handle GPU logic separately.
2242       // Infers the output type with the converted `act`.
2243       Type feature_type = RankedTensorType::get(
2244           {GetDimSize(act_type, feature_dim)}, kernel_type);
2245 
2246       SmallVector<Type, 3> operand_types = {act.getType(), feature_type,
2247                                             feature_type};
2248       auto training_op = rewriter.create<BatchNormGradOp>(
2249           loc, operand_types, act, scale, mean, var, grad, op.epsilon(),
2250           feature_dim);
2251 
2252       x_backprop = training_op.getResult(0);
2253 
2254       scale_backprop = training_op.getResult(1);
2255 
2256       offset_backprop = training_op.getResult(2);
2257     } else {  // inference
2258       SmallVector<int64_t, 4> non_feature_dims;
2259       for (int64_t i = 0; i < act_type.getRank(); ++i) {
2260         if (i == feature_dim) continue;
2261         non_feature_dims.push_back(i);
2262       }
2263       auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
2264       auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2265 
2266       // scratch1 = rsqrt(var + epsilon)
2267       RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
2268       auto epsilon = rewriter.create<ConstantOp>(
2269           loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
2270       auto add_op = rewriter.create<chlo::BroadcastAddOp>(
2271           loc, var, epsilon.getResult(), scalar_broadcast_dims);
2272 
2273       Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
2274 
2275       // scratch2 = sum(y_backprop * (x - mean))
2276       auto sub_op = rewriter.create<mhlo::SubtractOp>(
2277           loc, act,
2278           Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
2279       auto weighted_grad = rewriter.create<mhlo::MulOp>(loc, grad, sub_op);
2280       Value scratch2 =
2281           ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
2282 
2283       // x_backprop = y_backprop * (scale * scratch1)
2284       auto scaled_grad =
2285           rewriter.create<mhlo::MulOp>(loc, op.scale(), scratch1);
2286       x_backprop = rewriter.create<mhlo::MulOp>(
2287           loc, grad,
2288           Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
2289                                   rewriter));
2290 
2291       // scale_backprop = scratch2 * scratch1
2292       scale_backprop = rewriter.create<mhlo::MulOp>(loc, scratch1, scratch2);
2293 
2294       // offset_backprop = sum(y_backprop)
2295       offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
2296     }
2297 
2298     x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
2299     Value last_val[2];
2300     if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) {
2301       // It doesn't matter what values we provide for the last 2 results.
2302       last_val[0] = last_val[1] = op.x();
2303     } else {
2304       auto const_val = rewriter.create<ConstantOp>(
2305           op.getLoc(),
2306           DenseElementsAttr::get<float>(
2307               RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))),
2308               0.0));
2309       auto maybe_cast = [&](Value val, Type t) -> Value {
2310         if (val.getType() == t) return val;
2311         return rewriter.create<tensor::CastOp>(op.getLoc(), t, val);
2312       };
2313       last_val[0] = maybe_cast(const_val, op.getResult(3).getType());
2314       last_val[1] = maybe_cast(const_val, op.getResult(4).getType());
2315     }
2316     rewriter.replaceOp(
2317         op, {/*x_backprop=*/x_backprop,
2318              /*scale_backprop=*/scale_backprop,
2319              /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]});
2320     return success();
2321   }
2322 };
2323 
2324 using ConvertFusedBatchNormGradOp =
2325     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
2326 using ConvertFusedBatchNormGradV2Op =
2327     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
2328 using ConvertFusedBatchNormGradV3Op =
2329     ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;
2330 
2331 // Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
2332 // HLO BatchNormInferenceOp, depending on the value of the 'is_training'
2333 // parameter.
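// In the training path, BatchNormTrainingOp returns the biased (population)
// variance, so a Bessel-corrected copy is produced for the batch_variance
// result, and a non-unit exponential_avg_factor blends the running statistics.
// Sketch of the arithmetic emitted below (illustrative only):
//
//   factor         = sample_size / max(1, sample_size - 1)
//   batch_variance = variance * factor
//   new_mean       = (1 - exponential_avg_factor) * old_mean
//                    + exponential_avg_factor * batch_mean
//   new_variance   = (1 - exponential_avg_factor) * old_variance
//                    + exponential_avg_factor * batch_variance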
2334 template <typename FusedBatchNormOpT>
2335 class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
2336  public:
2337   using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
2338 
2339   LogicalResult matchAndRewrite(FusedBatchNormOpT op,
2340                                 PatternRewriter &rewriter) const override {
2341     tensorflow::TensorFormat data_format;
2342     if (!FormatFromString(op.data_format().str(), &data_format))
2343       return op.emitOpError("invalid data format");
2344 
2345     auto feature_dim = getFeatureDimensionAttr(rewriter, data_format, op.x());
2346 
2347     auto input_type_tensor = op.x().getType().template cast<TensorType>();
2348     auto input_element_type = input_type_tensor.getElementType();
2349 
2350     auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
2351     auto scale_element_type = scale_type_tensor.getElementType();
2352 
2353     auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
2354     auto mean_element_type = mean_type_tensor.getElementType();
2355     // In the training case, dimensions of input tensors must be static.
2356     if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
2357                              !scale_type_tensor.hasStaticShape() ||
2358                              !mean_type_tensor.hasStaticShape()))
2359       return failure();
2360 
2361     // TODO(b/69928690): Support mixed precision in the XLA batch
2362     // normalization operators. As a workaround, create a new x with the same
2363     // element type as scale (which may be more precise than the input type).
2364     Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
2365                                                             scale_element_type);
2366     TensorType bn_train_input_type_tensor =
2367         bn_train_input.getType().template cast<TensorType>();
2368 
2369     if (op.is_training()) {
2370       // Training case.
2371       auto operand_shape = bn_train_input_type_tensor.getShape();
2372       // The mean and variance are each 1-dimensional arrays whose size equals
2373       // the feature dimension, with the same element type as the operand (x).
2374       // This shape must be constructed manually because the mean and variance
2375       // inputs are empty in the training case.
2376       Type mean_var_type = RankedTensorType::get(
2377           {operand_shape[feature_dim.getInt()]}, scale_element_type);
2378       // Op result type is a tuple of 3 values: output with the same shape as
2379       // the input, batch_mean, and batch_var.
2380       SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
2381                                             mean_var_type, mean_var_type};
2382       auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
2383           op.getLoc(), operand_types, bn_train_input, op.scale(), op.offset(),
2384           op.epsilon(), feature_dim.getInt());
2385       // HLO op outputs a tuple of tensors. Extract those results.
2386       Value y_out = bn_train_op.getResult(0);
2387       Value batch_mean = bn_train_op.getResult(1);
2388       Value reserve_space_1 = batch_mean;
2389       Value batch_variance = bn_train_op.getResult(2);
2390 
2391       // Apply Bessel's correction on the variance.
2392       int total_input_size = bn_train_input_type_tensor.getNumElements();
2393       int total_scale_size = scale_type_tensor.getNumElements();
2394       int sample_size = total_input_size / total_scale_size;
2395       int sample_size_minus_one = std::max(1, sample_size - 1);
2396       double factor = static_cast<double>(sample_size) /
2397                       static_cast<double>(sample_size_minus_one);
2398       auto factor_const_op = rewriter.create<mhlo::ConstantOp>(
2399           op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
2400 
2401       Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
2402           op.getLoc(), batch_variance.getType(), batch_variance,
2403           factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
2404 
2405       // Convert back to input type to stay aligned with expected output type
2406       // for TF op.
2407       y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), y_out,
2408                                                input_element_type);
2409 
2410       float exponential_avg_factor =
2411           op.exponential_avg_factor().convertToFloat();
2412       if (exponential_avg_factor != 1.0f) {
2413         auto alpha = rewriter.create<mhlo::ConstantOp>(
2414             op.getLoc(), rewriter.getFloatAttr(mean_element_type,
2415                                                1.0f - exponential_avg_factor));
2416         auto beta = rewriter.create<mhlo::ConstantOp>(
2417             op.getLoc(),
2418             rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
2419 
2420         // new_running_mean = alpha * old_mean + beta * batch_mean.
2421         auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
2422             op.getLoc(), op.mean().getType(), alpha, op.mean(),
2423             /*broadcast_dimensions=*/DenseIntElementsAttr());
2424         auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
2425             op.getLoc(), batch_mean.getType(), beta, batch_mean,
2426             /*broadcast_dimensions=*/DenseIntElementsAttr());
2427         batch_mean = rewriter.create<chlo::BroadcastAddOp>(
2428             op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
2429             /*broadcast_dimensions=*/DenseIntElementsAttr());
2430 
2431         // new_running_variance = alpha * old_variance + beta * batch_variance.
2432         auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
2433             op.getLoc(), op.variance().getType(), alpha, op.variance(),
2434             /*broadcast_dimensions=*/DenseIntElementsAttr());
2435         auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
2436             op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
2437             /*broadcast_dimensions=*/DenseIntElementsAttr());
2438         corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
2439             op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
2440             /*broadcast_dimensions=*/DenseIntElementsAttr());
2441       }
2442 
2443       if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2444         // FusedBatchNormV2 expects 5 outputs.
2445         // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
2446         // They are used to pass the per-batch mean and variance to the
2447         // gradient. Here we maintain the same behavior by setting them to the
2448         // mean and variance calculated by BatchNormTraining.
2449         rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2450                                 /*batch_variance=*/corrected_variance,
2451                                 /*reserve_space_1=*/reserve_space_1,
2452                                 /*reserve_space_2=*/batch_variance});
2453       } else {  // TF::FusedBatchNormV3Op
2454         // For FusedBatchNormV3Op, also create a constant tensor to forward to
2455         // last reserve_space_3 output.
2456         auto reserve_space_3_type =
2457             op.getResult(5).getType().template cast<TensorType>();
2458         int num_elements = reserve_space_3_type.hasStaticShape()
2459                                ? reserve_space_3_type.getNumElements()
2460                                : 0;
2461         auto const_attr_type = RankedTensorType::get(
2462             {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2463         Value dummy_const = rewriter.create<ConstantOp>(
2464             op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2465         if (const_attr_type != reserve_space_3_type)
2466           dummy_const = rewriter.create<tensor::CastOp>(
2467               op.getLoc(), reserve_space_3_type, dummy_const);
2468         rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2469                                 /*batch_variance=*/corrected_variance,
2470                                 /*reserve_space_1=*/reserve_space_1,
2471                                 /*reserve_space_2=*/batch_variance,
2472                                 /*reserve_space_3=*/dummy_const});
2473       }
2474     } else {  // Inference case.
2475       auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
2476           op.getLoc(),
2477           /*result_type=*/bn_train_input_type_tensor, bn_train_input,
2478           op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
2479           feature_dim.getInt());
2480 
2481       // Convert back to input type to stay aligned with expected output type
2482       // for TF op.
2483       auto y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), bn_train_op,
2484                                                     input_element_type);
2485 
2486       // The mean, variance, and reserved space outputs of the batch norm op are
2487       // not used for inference. It doesn't matter what values we provide for
2488       // the last 5 results as long as they are of the same type. Forward
2489       // input mean and variance to output mean, variance, reserved_space_1 and
2490       // reserved_space_2.
2491       if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2492         rewriter.replaceOp(op, {/*y=*/y_out,
2493                                 /*batch_mean=*/op.mean(),
2494                                 /*batch_variance=*/op.variance(),
2495                                 /*reserve_space_1=*/op.mean(),
2496                                 /*reserve_space_2=*/op.variance()});
2497       } else {
2498         // For FusedBatchNormV3Op, also create a constant tensor to forward to
2499         // last reserve_space_3 output.
2500         auto reserve_space_3_type =
2501             op.getResult(5).getType().template cast<TensorType>();
2502         int num_elements = reserve_space_3_type.hasStaticShape()
2503                                ? reserve_space_3_type.getNumElements()
2504                                : 0;
2505         auto const_attr_type = RankedTensorType::get(
2506             {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2507         Value dummy_const = rewriter.create<ConstantOp>(
2508             op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2509         if (const_attr_type != reserve_space_3_type)
2510           dummy_const = rewriter.create<tensor::CastOp>(
2511               op.getLoc(), reserve_space_3_type, dummy_const);
2512         rewriter.replaceOp(op, {/*y=*/y_out,
2513                                 /*batch_mean=*/op.mean(),
2514                                 /*batch_variance=*/op.variance(),
2515                                 /*reserve_space_1=*/op.mean(),
2516                                 /*reserve_space_2=*/op.variance(),
2517                                 /*reserve_space_3=*/dummy_const});
2518       }
2519     }
2520     return success();
2521   }
2522 };
2523 
2524 using ConvertFusedBatchNormV2Op =
2525     ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
2526 using ConvertFusedBatchNormV3Op =
2527     ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
2528 
2529 using PaddingArray = std::vector<std::pair<int64_t, int64_t>>;
2530 
2531 // Returns padding values for ReduceWindow op as a vector of pairs.
2532 //
2533 // Requires padding to be either 'SAME' or 'VALID' and the number of input
2534 // dimensions to be equal to the size of window dimensions and window strides.
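// For example (illustrative): with an input dimension of size 4, window size 3
// and stride 2, 'SAME' padding produces the pair (0, 1): the dimension has
// ceil(4 / 2) = 2 windows, which need (2 - 1) * 2 + 3 = 5 padded inputs, so a
// single zero is added on the high side (cf. the AvgPoolGrad example below).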
2535 template <int num_dims>
2536 static PaddingArray GetReduceWindowPaddingAsArray(
2537     llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2538     ArrayAttr window_strides, StringRef padding, Builder *builder) {
2539   if (padding == "VALID") {
2540     return PaddingArray(num_dims, std::make_pair(0, 0));
2541   }
2542   assert(padding == "SAME");
2543   llvm::SmallVector<int64_t, num_dims> input_shape, window_shape, strides;
2544   input_shape.reserve(input_dims.size());
2545   window_shape.reserve(window_dims.size());
2546   strides.reserve(window_strides.size());
2547 
2548   for (const auto &dim : input_dims) input_shape.push_back(dim);
2549   for (Attribute attr : window_dims)
2550     window_shape.push_back(attr.cast<IntegerAttr>().getInt());
2551   for (Attribute attr : window_strides)
2552     strides.push_back(attr.cast<IntegerAttr>().getInt());
2553 
2554   PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
2555                                              ::xla::Padding::kSame);
2556   return paddings;
2557 }
2558 
2559 // Same as GetReduceWindowPaddingAsArray but returns padding as
2560 // DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
2561 template <int num_dims>
2562 static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
2563     llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2564     ArrayAttr window_strides, StringRef padding, Builder *builder) {
2565   if (padding == "VALID") return {};
2566   assert(padding == "SAME");
2567   PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
2568       input_dims, window_dims, window_strides, padding, builder);
2569   int64_t rank = paddings.size();
2570   llvm::SmallVector<int64_t, num_dims * 2> flatten_paddings(rank * 2);
2571   for (int i = 0; i < rank; i++) {
2572     flatten_paddings[2 * i] = paddings[i].first;
2573     flatten_paddings[2 * i + 1] = paddings[i].second;
2574   }
2575   return DenseIntElementsAttr::get(
2576       RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
2577       flatten_paddings);
2578 }
2579 
2580 // Helper function for dividing each entry of `pooled` by the count of its
2581 // corresponding window, i.e., the number of non-padding entries of the window
2582 // which an `AvgPool` operation performed on an `input_shape`-tensor would map
2583 // to this entry, depending on `ksize` and `strides`. This function is used for
2584 // `AvgPool` and `AvgPoolGrad` legalizations.
2585 // `zero` is passed as a parameter because it can be reused from caller level.
2586 // `pooled` must have `RankedTensorType`.
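// For example (illustrative, 1-D): with input size 4, window size 3, stride 2
// and 'SAME' padding, the two windows cover 3 and 2 original entries, so
// `pooled` is divided elementwise by [3, 2]; with 'VALID' padding every count
// is simply the product of `ksize`.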
2587 template <typename OpTy, int num_dims>
2588 Operation *AvgPoolDivideByCount(
2589     Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
2590     const SmallVector<int64_t, num_dims> &ksize,
2591     const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
2592     PatternRewriter &rewriter) {
2593   Location loc = op.getLoc();
2594   RankedTensorType pooled_type =
2595       pooled.getType().template cast<RankedTensorType>();
2596   Type element_type = pooled_type.getElementType();
2597   Operation *result = nullptr;
2598   RankedTensorType orig_input_type =
2599       RankedTensorType::get(input_shape, element_type);
2600 
2601   if (op.padding() == "VALID") {
2602     // All window counts are equal here because we don't have padding
2603     // (each entry of `pooled` corresponds to a window that consists of
2604     //  original input entries only).
2605     int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
2606                                            std::multiplies<int64_t>());
2607     // Divide `pooled` by window counts.
2608     Value divisor =
2609         GetScalarConstOfType(element_type, loc, window_count, &rewriter);
2610     auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2611     result = rewriter.create<chlo::BroadcastDivOp>(
2612         loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
2613   } else {
2614     assert(op.padding() == "SAME");
2615     // For SAME padding, only original entries that contributed to a window
2616     // are counted for the average of this window, not padded entries.
2617 
2618     // Build all-ones tensor of same shape as the original input.
2619     ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
2620     auto all_ones_tensor = rewriter.create<ConstantOp>(loc, splat);
2621 
2622     // Get padding for the input.
2623     DenseIntElementsAttr input_padding_attr =
2624         GetReduceWindowPaddingAsAttr<num_dims>(
2625             input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2626 
2627     // Count the 1's in each window, using the same padding as for the input,
2628     // which gives us the window counts by which `pooled` needs to be divided.
2629     auto divisor = rewriter.create<ReduceWindowOp>(
2630         loc, pooled_type,
2631         /*operand=*/all_ones_tensor,
2632         /*init_value=*/zero,
2633         /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2634         /*window_strides=*/GetI64ElementsAttr(op.strides()),
2635         /*base_dilations=*/DenseIntElementsAttr(),
2636         /*window_dilations=*/DenseIntElementsAttr(),
2637         /*padding=*/input_padding_attr);
2638     BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
2639 
2640     // Divide `pooled` by window counts.
2641     result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled,
2642                                           divisor.getResult(0));
2643   }
2644   return result;
2645 }
2646 
2647 Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
2648 Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
2649 
2650 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
2651 // dimensions with add as the reduction function. The reduction result is
2652 // then divided by the number of elements in the window.
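// Roughly, for the 'VALID' padding case (illustrative only; attributes elided
// and the divisor shown as a broadcast scalar holding the window size):
//
//   %init = mhlo.constant dense<0.0> : tensor<f32>
//   %sum  = "mhlo.reduce_window"(%input, %init) ({ ... mhlo.add ... })
//             {window_dimensions = ..., window_strides = ...}
//   %avg  = "chlo.broadcast_divide"(%sum, %window_size)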
2653 template <typename OpTy, int num_dims>
2654 class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
2655  public:
2656   using OpRewritePattern<OpTy>::OpRewritePattern;
2657 
2658   LogicalResult matchAndRewrite(OpTy op,
2659                                 PatternRewriter &rewriter) const override {
2660     Value input_value = GetAvgPoolInput(op);
2661     auto input_type =
2662         input_value.getType().template dyn_cast<RankedTensorType>();
2663     if (!input_type) return failure();
2664 
2665     // We will do accumulation first; use a larger bitwidth if suitable.
2666     Type input_element_type = input_type.getElementType();
2667     Type sum_element_type = GetSumAccumulationType(input_element_type);
2668     Type result_type;
2669 
2670     // The result type for reduction and division with the proper element type.
2671     if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
2672       result_type =
2673           RankedTensorType::get(ranked_type.getShape(), sum_element_type);
2674     else
2675       result_type = UnrankedTensorType::get(sum_element_type);
2676 
2677     // Convert if we need enlarge the element type's bitwidth.
2678     if (input_element_type != sum_element_type)
2679       input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
2680                                                sum_element_type);
2681 
2682     // Create the ReduceWindow op.
2683     Value init =
2684         GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
2685     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2686         input_type.getShape(), op.ksize(), op.strides(), op.padding(),
2687         &rewriter);
2688     auto reduce = rewriter.create<ReduceWindowOp>(
2689         op.getLoc(), result_type, input_value, init,
2690         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
2691         /*base_dilations=*/DenseIntElementsAttr(),
2692         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2693     BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
2694 
2695     // Count the number of elements in the window. The following calculation
2696     // is only valid when there is no padding.
2697     SmallVector<int64_t, num_dims> input_shape(
2698         llvm::to_vector<num_dims>(input_type.getShape()));
2699     SmallVector<int64_t, num_dims> ksize, strides;
2700     GetI64ArrayAttrValues(op.ksize(), &ksize);
2701     GetI64ArrayAttrValues(op.strides(), &strides);
2702 
2703     Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
2704         reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter);
2705 
2706     // Convert back if we enlarged the element type's bitwidth.
2707     Value result = result_op->getOpResult(0);
2708     if (input_element_type != sum_element_type)
2709       result =
2710           rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
2711 
2712     rewriter.replaceOp(op, result);
2713     return success();
2714   }
2715 };
2716 
2717 using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
2718 using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
2719 
2720 // `AvgPoolGradOp` is converted to the following operations:
2721 // 1. Divide each entry of the output gradient (the gradient for the previous
2722 //    layer in backpropagation order) by the count of the corresponding window
2723 //    (i.e., the number of non-padding entries of the window which `AvgPool`
2724 //    has mapped to this entry in forward propagation).
2725 // 2. Add appropriate interior and exterior padding for step 3 (see example
2726 //    below).
2727 // 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
2728 //    as windows) and stride 1 in each dimension. This is implemented as a
2729 //    `ReduceWindowOp` with `AddOp` as body.
2730 //
2731 // Example:
2732 // Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
2733 // and SAME padding with 0's. It is defined by
2734 //    f(x) = [ (x_1 + x_2 + x_3) / 3 ]      ( x = (x_1, x_2, x_3, x_4) )
2735 //           [ (x_3 + x_4 + 0)   / 2 ]      (the 0 results from right padding)
2736 // Note that for SAME padding in `AvgPool` the padded entries are not counted
2737 // for the average; this is why the second denominator is 2 and not 3.
2738 // The Jacobian Df is
2739 //    [ 1/3  1/3  1/3  0   ]
2740 //    [ 0    0    1/2  1/2 ]
2741 //
2742 // Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
2743 // needs the original input shape and not the tensor as argument).
2744 // Let v = [ 4  6 ]^T  be the output gradient (^T = transposed). Then the
2745 // average pool gradient is given by
2746 //    Df^T * v = [ 4/3  4/3  13/3  3 ]^T
2747 // Instead of a matrix-vector-multiplication we can utilize the sparsity and
2748 // structure of Df by using the 3-step approach from above:
2749 // 1. Divide output gradient v by window counts: [ 4/3  6/2 ]^T
2750 // 2. Add appropriate padding: [ 0  0  4/3  0  3  0 ]^T
2751 // 3. Convolve with kernel [ 1  1  1 ]: [ 4/3  4/3  13/3  3 ]^T
2752 //
2753 // Note that the padding in step 2. is chosen in such a way that the subsequent
2754 // convolution produces the gradient. Higher dimensions, different padding, and
2755 // different windows/strides work in a similar way; the main difference is in
2756 // the computation of the paddings in step 2.
2757 //
2758 // For more details on backpropagation for convolution of which `AvgPoolGrad`
2759 // is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
2760 // `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
2761 // examples for different cases.
2762 template <typename OpTy, int num_dims>
2763 class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
2764   using DimVector = SmallVector<int64_t, num_dims>;
2765 
2766  public:
2767   using OpRewritePattern<OpTy>::OpRewritePattern;
2768 
2769   LogicalResult matchAndRewrite(OpTy op,
2770                                 PatternRewriter &rewriter) const override {
2771     Location loc = op.getLoc();
2772     tensorflow::TensorFormat data_format;
2773     if (!FormatFromString(op.data_format().str(), &data_format)) {
2774       return op.emitOpError("invalid data format");
2775     }
2776     // `out_grad` is the gradient that was propagated via backpropagation from
2777     // the output layer.
2778     Value out_grad = op.grad();
2779     auto out_grad_type =
2780         out_grad.getType().template dyn_cast<RankedTensorType>();
2781     if (!out_grad_type) {
2782       return failure();
2783     }
2784     Type element_type = out_grad_type.getElementType();
2785     DenseIntElementsAttr orig_input_shape_attr;
2786     if (!matchPattern(op.orig_input_shape(),
2787                       m_Constant(&orig_input_shape_attr))) {
2788       return failure();
2789     }
2790     auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
2791     DimVector orig_input_shape(orig_input_shape_values.begin(),
2792                                orig_input_shape_values.end());
2793     DimVector ksize, strides;
2794     GetI64ArrayAttrValues(op.ksize(), &ksize);
2795     GetI64ArrayAttrValues(op.strides(), &strides);
2796     Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
2797 
2798     auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
2799         out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
2800 
2801     // Get same padding as for original input.
2802     PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
2803         orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2804 
2805     // Add padding around `out_grad_divided` values in such a way that the
2806     // subsequent `ReduceWindowOp` produces the gradient.
2807     DimVector out_grad_shape(
2808         llvm::to_vector<num_dims>(out_grad_type.getShape()));
2809     DimVector low_padding(num_dims, 0);
2810     DimVector high_padding(num_dims, 0);
2811     DimVector interior_padding(num_dims, 0);
2812     constexpr int num_spatial_dims = num_dims - 2;
2813     for (int i = 0; i < num_spatial_dims; ++i) {
2814       int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
2815       int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
2816                                            orig_padding[dim].first +
2817                                            orig_padding[dim].second;
2818       // Set interior padding such that neighboring entries from
2819       // `out_grad_divided` have distance `strides[dim]` from each other in
2820       // every dimension.
2821       interior_padding[dim] = strides[dim] - 1;
2822       // Set exterior padding in the same way as for convolution gradient
2823       // computation.
2824       auto status = ::xla::ConvGradExtractAndVerifyDimension(
2825           /*input_size=*/orig_input_shape_padded_in_dim,
2826           /*filter_size=*/ksize[dim],
2827           /*output_size=*/out_grad_shape[dim],
2828           /*dilation=*/1,
2829           /*stride=*/strides[dim],
2830           /*padding=*/::xla::Padding::kValid);
2831       if (!status.ok()) {
2832         return failure();
2833       }
2834       ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
2835           status.ValueOrDie();
2836       // Subtract the original exterior padding since it doesn't contribute to
2837       // the gradient. Note that we save one `PadOp` and some unnecessary kernel
2838       // computations, compared to the `xla::AvgPoolGrad` implementation, by
2839       // subtracting the original exterior padding before `ReduceWindowOp`
2840       // instead of trimming the result of `ReduceWindowOp` (the final result is
2841       // the same because all strides are 1).
2842       low_padding[dim] =
2843           conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
2844       high_padding[dim] =
2845           conv_grad_spatial_dim.pad_after - orig_padding[dim].second;
2846 
2847       // Update `out_grad_shape` to result shape of following `PadOp`.
2848       out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
2849                             (out_grad_shape[dim] - 1) * strides[dim] + 1;
2850     }
2851     Value reduce_window_input = rewriter.create<PadOp>(
2852         loc, RankedTensorType::get(out_grad_shape, element_type),
2853         /*operand=*/out_grad_divided->getOpResult(0),
2854         /*padding_value=*/zero,
2855         /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
2856         /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
2857         /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));
2858 
2859     // Compute result by convolving `reduce_window_input` with an all-ones
2860     // kernel, using `ReduceWindowOp` with `AddOp` body.
2861 
2862     Type sum_element_type = GetSumAccumulationType(element_type);
2863     if (element_type != sum_element_type) {
2864       // Convert to appropriate sum accumulation type to avoid precision loss.
2865       reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
2866                                                        sum_element_type);
2867       zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
2868     }
2869     auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
2870     auto reduce_window_op = rewriter.create<ReduceWindowOp>(
2871         loc, RankedTensorType::get(orig_input_shape, sum_element_type),
2872         /*operand=*/reduce_window_input,
2873         /*init_value=*/zero,
2874         /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2875         /*window_strides=*/ones,
2876         /*base_dilations=*/DenseIntElementsAttr(),
2877         /*window_dilations=*/DenseIntElementsAttr(),
2878         /*padding=*/DenseIntElementsAttr());
2879     BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
2880                            &rewriter);
2881     Value result = reduce_window_op.getResult(0);
2882 
2883     if (element_type != sum_element_type) {
2884       // Convert back to original element type.
2885       result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
2886     }
2887     rewriter.replaceOp(op, {result});
2888     return success();
2889   }
2890 };
2891 
2892 using ConvertAvgPool2DGradOp =
2893     ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
2894 using ConvertAvgPool3DGradOp =
2895     ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
2896 
2897 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
2898 // dimensions with max as the reduction function.
2899 //
2900 // Sample result for VALID padding mode:
2901 //
2902 //   %init = arith.constant dense<...> : tensor<i32>
2903 //   %max_pool = "mhlo.reduce_window"(%inp, %init) ["mhlo.maximum"]
2904 //               {window_dimensions = ..., window_strides = ... }
2905 //
2906 template <typename OpTy, int num_dims>
2907 class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
2908  public:
2909   using OpRewritePattern<OpTy>::OpRewritePattern;
2910 
2911   LogicalResult matchAndRewrite(OpTy op,
2912                                 PatternRewriter &rewriter) const override {
2913     Type element_type =
2914         op.input().getType().template cast<TensorType>().getElementType();
2915     if (!element_type.isSignlessIntOrFloat()) return failure();
2916     tensorflow::Padding padding;
2917     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
2918       return failure();
2919     if (padding == tensorflow::Padding::EXPLICIT) {
2920       return failure();
2921     }
2922     Location loc = op.getLoc();
2923     ConstantOp init = GetScalarLimitConstOfType(
2924         element_type, loc, hlo::kInfinityLowest, &rewriter);
2925 
2926     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
2927     if (!input_ty) return failure();
2928     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2929         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
2930     auto reduce = rewriter.create<ReduceWindowOp>(
2931         loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
2932         GetI64ElementsAttr(op.strides()),
2933         /*base_dilations=*/DenseIntElementsAttr(),
2934         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2935     BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
2936 
2937     rewriter.replaceOp(op, reduce.getResult(0));
2938     return success();
2939   }
2940 };
2941 
2942 using ConvertMaxPool2DOp = ConvertMaxPoolOp<TF::MaxPoolOp, /*num_dims=*/4>;
2943 using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;
2944 
2945 // Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on
2946 // the condition only.
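// For example (illustrative), a select with a vector condition such as
//
//   %r = "tf.Select"(%cond, %t, %e)
//          : (tensor<2xi1>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
//
// is lowered to shape constraints (shape.cstr_eq / shape.assuming) plus an
// mhlo.select whose condition has first been broadcast to tensor<2x3xi1> via
// mhlo.dynamic_broadcast_in_dim.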
2947 class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> {
2948  public:
2949   using OpRewritePattern::OpRewritePattern;
2950 
2951   LogicalResult matchAndRewrite(TF::SelectOp op,
2952                                 PatternRewriter &rewriter) const override {
2953     // This lowering only works on ranked types.
2954     auto cond_type = op.condition().getType().dyn_cast<RankedTensorType>();
2955     auto then_type = op.t().getType().dyn_cast<RankedTensorType>();
2956     auto else_type = op.e().getType().dyn_cast<RankedTensorType>();
2957     if (!cond_type || !then_type || !else_type) {
2958       return failure();
2959     }
2960 
2961     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
2962     Value cond_shape = b.createOrFold<shape::ShapeOfOp>(op.condition());
2963     Value then_shape = b.createOrFold<shape::ShapeOfOp>(op.t());
2964     Value else_shape = b.createOrFold<shape::ShapeOfOp>(op.e());
2965 
2966     // First check that the `then` and `else` shapes are the equal.
2967     Value assumption =
2968         b.createOrFold<shape::CstrEqOp>(ValueRange{then_shape, else_shape});
2969     // For a vector cond we also verify that the majormost dim of `then` matches
2970     // the vector size. To do that, split off the first dim of `then`.
2971     bool needs_broadcast = cond_type.getRank() == 1 && then_type.getRank() != 1;
2972     Value then_shape_split = then_shape;
2973     if (needs_broadcast) {
2974       Value const_one = b.create<arith::ConstantIndexOp>(1);
2975       Type extent_first = shape::getExtentTensorType(b.getContext(), 1);
2976       Type extent_second =
2977           shape::getExtentTensorType(b.getContext(), then_type.getRank() - 1);
2978       SmallVector<Value, 2> then_split;
2979       b.createOrFold<shape::SplitAtOp>(then_split,
2980                                        TypeRange{extent_first, extent_second},
2981                                        then_shape, const_one);
2982       then_shape_split = then_split[0];
2983     }
2984     // If the condition is not a scalar, check that it matches the other shapes.
2985     if (cond_type.getRank() > 0) {
2986       Value eq_cstr = b.createOrFold<shape::CstrEqOp>(
2987           ValueRange{cond_shape, then_shape_split});
2988       auto witness = shape::WitnessType::get(b.getContext());
2989       assumption = b.createOrFold<shape::AssumingAllOp>(
2990           witness, ValueRange{assumption, eq_cstr});
2991     }
2992     auto result_type = op.getResult().getType().cast<TensorType>();
2993     auto assuming_op =
2994         b.create<shape::AssumingOp>(ArrayRef<Type>{result_type}, assumption);
2995 
2996     OpBuilder::InsertionGuard guard(b);
2997     b.createBlock(&assuming_op.getDoRegion());
2998 
2999     // Broadcast the cond if necessary.
3000     Value cond = op.condition();
3001     if (needs_broadcast) {
3002       Value result_extents = b.create<shape::ToExtentTensorOp>(
3003           GetExtentsTensorTypeFor(result_type), then_shape);
3004       cond = b.create<mhlo::DynamicBroadcastInDimOp>(
3005           RankedTensorType::get(result_type.getShape(), b.getI1Type()), cond,
3006           result_extents, GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b));
3007     }
3008     Value select = b.create<mhlo::SelectOp>(result_type, cond, op.t(), op.e());
3009     b.create<shape::AssumingYieldOp>(select);
3010     rewriter.replaceOp(op, {assuming_op.getResult(0)});
3011     return success();
3012   }
3013 };
3014 
3015 // Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
3016 //
3017 //     sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
3018 //
3019 // Sample result with 2-d f16 inputs with B batches of N elements each.
3020 //
3021 //    // Create an array of 0.5 the shape of the input array.
3022 //    %half = mhlo.constant dense<5.000000e-01> : tensor<f32>
3023 //    %half_array = "mhlo.broadcast"(%half)
3024 //                           {broadcast_sizes = dense<2> : tensor<1xi64>}
3025 //                           : (tensor<f32>) -> tensor<2xf32>
3026 //
3027 //    // Compute Tanh of half the logits of the values.
3028 //    %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32>
3029 //    %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
3030 //
3031 //    // Halve the result of Tanh and add 0.5.
3032 //    %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32>
3033 //    %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32>
3034 //
3035 class ConvertSigmoidOp : public RewritePattern {
3036  public:
3037   explicit ConvertSigmoidOp(MLIRContext *context)
3038       : RewritePattern(
3039             TF::SigmoidOp::getOperationName(), 0, context,
3040             {mhlo::ConstantOp::getOperationName(),
3041              shape::ShapeOfOp::getOperationName(),
3042              shape::ToExtentTensorOp::getOperationName(),
3043              mhlo::DynamicBroadcastInDimOp::getOperationName(),
3044              mhlo::MulOp::getOperationName(), mhlo::TanhOp::getOperationName(),
3045              mhlo::AddOp::getOperationName()}) {}
3046 
3047   LogicalResult matchAndRewrite(Operation *sigmoid_op,
3048                                 PatternRewriter &rewriter) const override {
3049     auto op = cast<TF::SigmoidOp>(sigmoid_op);
3050     Location loc = op.getLoc();
3051 
3052     // Create constant half with shape and element type same as the operand.
3053     Value operand = op.getOperand();
3054     auto operand_ty = operand.getType().cast<TensorType>();
3055     auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType());
3056     ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5);
3057     auto scalar_half = rewriter.create<ConstantOp>(loc, attr);
3058     auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter);
3059 
3060     auto scaled_input = rewriter.create<MulOp>(loc, operand, half);
3061     auto tanh_op = rewriter.create<TanhOp>(loc, scaled_input);
3062     auto mul_op = rewriter.create<MulOp>(loc, tanh_op, half);
3063     auto add_op = rewriter.create<AddOp>(loc, mul_op, half);
3064 
3065     rewriter.replaceOp(op, add_op.getResult());
3066     return success();
3067   }
3068 };
3069 
3070 // Converts the tf.Slice op into mhlo.real_dynamic_slice.
3071 // TODO(disc): Recover the static special case's performance with folding and
3072 // canonicalization.
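// Roughly (illustrative only), each dimension's bounds are computed with
// scalar arith ops, where a size of -1 means "slice to the end":
//
//   begin_i = extract(%begin, i)
//   end_i   = (size_i == -1) ? dim(%input, i) : begin_i + size_i
//
// and the gathered index tensors feed a single mhlo.real_dynamic_slice with
// unit strides.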
3073 class ConvertSliceOpDynamic : public OpRewritePattern<TF::SliceOp> {
3074  public:
3075   using OpRewritePattern::OpRewritePattern;
3076 
3077   LogicalResult matchAndRewrite(TF::SliceOp op,
3078                                 PatternRewriter &rewriter) const override {
3079     Location loc = op.getLoc();
3080     Value input = op.input();
3081     Value begin_indices = op.begin();
3082     Value sizes = op.size();
3083 
3084     auto input_ty = input.getType().dyn_cast<RankedTensorType>();
3085     auto begin_type = begin_indices.getType().dyn_cast<RankedTensorType>();
3086     auto size_type = sizes.getType().dyn_cast<RankedTensorType>();
3087 
3088     if (!input_ty || !begin_type || !size_type ||
3089         !begin_type.hasStaticShape() || !size_type.hasStaticShape() ||
3090         begin_type.getRank() != 1 || size_type.getRank() != 1) {
3091       return failure();
3092     }
3093     // TODO(disc): remove static shape check once folding/canonicalization func
3094     // added
3095     DenseIntElementsAttr size_attr;
3096     if (matchPattern(op.size(), m_Constant(&size_attr))) {
3097       return failure();
3098     }
3099 
3100     int rank = begin_type.getDimSize(0);
3101     auto shape_scalar_type = begin_type.getElementType();
3102     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
3103     SmallVector<Value, 4> stride_values(rank, one);
3104     SmallVector<Value, 4> end_values;
3105     SmallVector<Value, 4> begin_values;
3106     end_values.reserve(rank);
3107     for (int i = 0; i < rank; ++i) {
3108       SmallVector<Value, 4> indices;
3109       indices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
3110       auto begin_value =
3111           rewriter.create<tensor::ExtractOp>(loc, begin_indices, indices);
3112       auto size_value = rewriter.create<tensor::ExtractOp>(loc, sizes, indices);
3113       Value minus_one = rewriter.create<arith::IndexCastOp>(
3114           loc, shape_scalar_type,
3115           rewriter.create<arith::ConstantIndexOp>(loc, -1));
3116       auto is_minus_one = rewriter.create<arith::CmpIOp>(
3117           loc, arith::CmpIPredicate::eq, size_value, minus_one);
3118       Value end_value =
3119           rewriter.create<arith::AddIOp>(loc, begin_value, size_value);
3120       auto dim_value = rewriter.create<arith::IndexCastOp>(
3121           loc, shape_scalar_type,
3122           rewriter.create<tensor::DimOp>(loc, input, i));
3123       end_value = rewriter.create<mlir::arith::SelectOp>(loc, is_minus_one,
3124                                                          dim_value, end_value);
3125       auto end_value_casted = rewriter.create<arith::IndexCastOp>(
3126           loc, rewriter.getIndexType(), end_value);
3127       end_values.push_back(end_value_casted);
3128 
3129       auto begin_value_casted = rewriter.create<arith::IndexCastOp>(
3130           loc, rewriter.getIndexType(), begin_value);
3131       begin_values.push_back(begin_value_casted);
3132     }
3133     auto index_ty = rewriter.getIndexType();
3134     auto start_indices = rewriter.create<tensor::FromElementsOp>(
3135         loc,
3136         RankedTensorType::get({static_cast<int64_t>(begin_values.size())},
3137                               index_ty),
3138         begin_values);
3139     auto end_indices = rewriter.create<tensor::FromElementsOp>(
3140         loc,
3141         RankedTensorType::get({static_cast<int64_t>(end_values.size())},
3142                               index_ty),
3143         end_values);
3144     auto stride_indices = rewriter.create<tensor::FromElementsOp>(
3145         loc,
3146         RankedTensorType::get({static_cast<int64_t>(stride_values.size())},
3147                               index_ty),
3148         stride_values);
3149 
3150     auto d_slice = rewriter.create<mhlo::RealDynamicSliceOp>(
3151         loc, op.getOperation()->getResult(0).getType(), input, start_indices,
3152         end_indices, stride_indices);
3153     rewriter.replaceOp(op, d_slice.getOperation()->getResults());
3154     return success();
3155   }
3156 };
3157 
3158 static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
3159                                            Value *out_lhs, Value *out_rhs,
3160                                            PatternRewriter *rewriter) {
3161   // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
3162   // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
3163   // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
3164   // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
3165   // To perform the matmul, we need to first broadcast lhs and rhs to a common
3166   // set of leading dimensions before doing the actual matmul.
3167   // That's what the code below does.
3168   // In particular, we populate out_lhs and out_rhs to have dimension structure:
3169   // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
3170   // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
3171   // To do this, we need to calculate those output shapes, which involves
3172   // slicing off the leading batch dims of each operand, broadcasting them,
3173   // then concatenating the broadcasted leading dims back to the row/col dims.
3174   // Finally, we create a TF::BroadcastTo op that does the actual broadcast.
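  // For example (illustrative): with lhs of shape [1, 3, 4, 5] and rhs of
  // shape [2, 3, 5, 6], the batch dims [1, 3] and [2, 3] broadcast to [2, 3],
  // so out_lhs gets shape [2, 3, 4, 5], out_rhs gets shape [2, 3, 5, 6], and
  // the subsequent matmul yields [2, 3, 4, 6].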
3175 
3176   // TODO(silvasean): Reduce duplication across reified shape calculations and
3177   // the static computation of output types needed to create ops.
3178   Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
3179   Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
3180   Value const_neg2 =
3181       rewriter->create<arith::ConstantOp>(loc, rewriter->getIndexAttr(-2));
3182   auto shape_type = shape::ShapeType::get(rewriter->getContext());
3183   auto lhs_splitted = rewriter->create<shape::SplitAtOp>(
3184       loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2);
3185   auto rhs_splitted = rewriter->create<shape::SplitAtOp>(
3186       loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2);
3187   auto lhs_type = lhs.getType().cast<RankedTensorType>();
3188   auto rhs_type = rhs.getType().cast<RankedTensorType>();
3189   // The last two dimensions are the matrix row/col dimensions. Don't broadcast
3190   // them.
3191   SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
3192   mlir::OpTrait::util::getBroadcastedShape(
3193       lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2),
3194       result_batch_shape_compile_time_extents);
3195   auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
3196       loc, shape_type, lhs_splitted.getHead(), rhs_splitted.getHead(),
3197       /*error=*/nullptr);
3198   // Lambda which handles the broadcasting of one side to the common
3199   // leading-batch dimensions.
3200   auto broadcast_one_side = [&](Value side, RankedTensorType type,
3201                                 Value tail_shape, Value *out_side) {
3202     ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
3203     auto result_shape = result_batch_shape_compile_time_extents;
3204     result_shape.append(matrix_dims.begin(), matrix_dims.end());
3205     auto result_type =
3206         RankedTensorType::get(result_shape, type.getElementType());
3207     auto shape = rewriter->create<shape::ConcatOp>(
3208         loc, shape_type, result_batch_shape, tail_shape);
3209     auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
3210         loc,
3211         RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
3212                               rewriter->getIndexType()),
3213         shape);
3214     *out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
3215                                                     shape_tensor);
3216   };
3217   broadcast_one_side(lhs, lhs_type, lhs_splitted.getTail(), out_lhs);
3218   broadcast_one_side(rhs, rhs_type, rhs_splitted.getTail(), out_rhs);
3219 }
3220 
3221 class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
3222  public:
3223   // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved
3224   // to CHLO and it is missing legalization to MHLO. Once that is done, this
3225   // pattern's benefit can be changed back to one, and the fallback lowering
3226   // pattern for the op can be removed.
3227   //
3228   // Set benefit of this pattern to zero to prefer the fallback pattern when
3229   // available and applicable. That pattern avoids broadcast on operands and is
3230   // therefore faster.
3231   //
3232   // Native legalization for BatchMatMulV3 needs to be added as well.
3233   explicit ConvertBatchMatMulV2Op(MLIRContext *context)
3234       : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
3235 
3236   LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
3237                                 PatternRewriter &rewriter) const override {
3238     Value lhs = op.x();
3239     Value rhs = op.y();
3240     auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
3241     auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
3242     if (!lhs_type || !rhs_type) return failure();
3243     if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
3244       lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
3245     }
3246     if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
3247       rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
3248     }
3249 
3250     // Broadcast both operands.
3251     BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
3252                                    &rewriter);
3253     lhs_type = lhs.getType().cast<RankedTensorType>();
3254     rhs_type = rhs.getType().cast<RankedTensorType>();
3255     assert(lhs_type.getRank() == rhs_type.getRank());
3256     int64_t rank = lhs_type.getRank();
3257     auto batch_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2));
3258     auto lhs_contracting_dimensions = llvm::to_vector<4>(
3259         llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}));
3260     auto rhs_contracting_dimensions = llvm::to_vector<4>(
3261         llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}));
3262     auto dimension_numbers = DotDimensionNumbersAttr::get(
3263         rewriter.getContext(),
3264         /*lhs_batching_dimensions=*/batch_dimensions,
3265         /*rhs_batching_dimensions=*/batch_dimensions,
3266         /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
3267         /*rhs_contracting_dimensions=*/rhs_contracting_dimensions);
3268     // TODO(silvasean): Emit shape checks for contracting dimensions.
3269     // (The batch dimensions are checked by the broadcasting logic)
3270     rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
3271                                               dimension_numbers,
3272                                               /*precision_config=*/nullptr);
3273     return success();
3274   }
3275 };
3276 
3277 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
3278 // split has fully static shape and the dimension to split is a constant.
3279 //
3280 // The main logic of this pattern is to calculate the start and end index range
3281 // for each slice. This only varies along the dimension to be split; for all
3282 // other dimensions, every resultant slice covers the input tensor's full
3283 // range. Strides for all resultant slices are one.
3284 //
3285 // For example, the following source IR:
3286 //
3287 //   %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3288 //   %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
3289 //                (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
3290 //
3291 // will be converted into:
3292 //
3293 //   %0 = "mhlo.slice"(%input) {
3294 //             limit_indices = dense<[4, 2]> : tensor<2xi64>,
3295 //             start_indices = dense<0> : tensor<2xi64>,
3296 //             strides = dense<1> : tensor<2xi64>} :
3297 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3298 //   %1 = "mhlo.slice"(%input) {
3299 //             limit_indices = dense<4> : tensor<2xi64>,
3300 //             start_indices = dense<[0, 2]> : tensor<2xi64>,
3301 //             strides = dense<1> : tensor<2xi64>} :
3302 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3303 //   %2 = "mhlo.slice"(%input) {
3304 //             limit_indices = dense<[4, 6]> : tensor<2xi64>,
3305 //             start_indices = dense<[0, 4]> : tensor<2xi64>,
3306 //             strides = dense<1> : tensor<2xi64>} :
3307 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3308 // TODO(antiagainst): consider lowering into TF ops so the pattern can be more
3309 // applicable.
3310 class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
3311  public:
3312   using OpRewritePattern::OpRewritePattern;
3313 
3314   LogicalResult matchAndRewrite(TF::SplitOp op,
3315                                 PatternRewriter &rewriter) const override {
3316     // We can only split along static dimensions.
3317     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3318     if (!input_type) return failure();
3319 
3320     // We can only match when the split dimension is a constant scalar.
3321     DenseIntElementsAttr split_dim_attr;
3322     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3323       return failure();
3324 
3325     // Get the dimension we are splitting at. Offset properly if it's negative.
3326     int64_t input_rank = input_type.getRank();
3327     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3328     if (dim_index < 0) dim_index += input_rank;
3329 
3330     // Calculate the dimension size for each slice along the split dimension.
3331     int64_t input_dim_size = input_type.getDimSize(dim_index);
3332     // If we are splitting along a dynamic dimension then we cannot compute
3333     // the static dimension length.
3334     if (ShapedType::isDynamic(input_dim_size)) return failure();
3335 
3336     int64_t num_splits = op.getNumResults();
3337     int64_t slice_size = input_dim_size / num_splits;
3338 
3339     // Get each slice's type.
3340     auto slice_shape = llvm::to_vector<4>(input_type.getShape());
3341     slice_shape[dim_index] = slice_size;
3342     Type slice_type =
3343         RankedTensorType::get(slice_shape, input_type.getElementType());
3344 
3345     // Parameters for constructing each slice.
3346     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3347     auto end_indices = llvm::to_vector<4>(input_type.getShape());
3348     SmallVector<int64_t, 4> strides(input_rank, 1);
3349 
3350     // All HLO slice results used to replace the original tf.Split op.
3351     SmallVector<Value, 4> slices;
3352     slices.reserve(num_splits);
3353 
3354     for (int i = 0; i < num_splits; ++i) {
3355       begin_indices[dim_index] = i * slice_size;
3356       end_indices[dim_index] = (i + 1) * slice_size;
3357       slices.push_back(
3358           rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
3359                                    GetI64ElementsAttr(begin_indices, &rewriter),
3360                                    GetI64ElementsAttr(end_indices, &rewriter),
3361                                    GetI64ElementsAttr(strides, &rewriter)));
3362     }
3363 
3364     rewriter.replaceOp(op, slices);
3365     return success();
3366   }
3367 };
3368 
3369 // Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops when
3370 // the dimension to split is a constant.
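//
// A minimal sketch of the resulting IR (the op names below are the ones used
// by this pattern, but the textual form is illustrative only), splitting a
// tensor<4x?xf32> into two pieces along the dynamic dimension 1:
//
//   %d0 = tensor.dim %input, %c0
//   %d1 = tensor.dim %input, %c1
//   %size = arith.divsi %d1, %c2
//   %begin0 = tensor.from_elements %c0, %c0 : tensor<2xindex>
//   %end0 = tensor.from_elements %d0, %size : tensor<2xindex>
//   %0 = "mhlo.real_dynamic_slice"(%input, %begin0, %end0, %strides)
//   ... and similarly for the remaining results ...
//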
3371 // TODO(disc): Recover the static special case's performance with folding and
3372 // canonicalization, then delete ConvertSplitOp.
3373 class ConvertSplitOpDynamic : public OpRewritePattern<TF::SplitOp> {
3374  public:
3375   using OpRewritePattern::OpRewritePattern;
3376 
3377   LogicalResult matchAndRewrite(TF::SplitOp op,
3378                                 PatternRewriter &rewriter) const override {
3379     Location loc = op.getLoc();
3380     Value input = op.value();
3381     auto input_type = input.getType().dyn_cast<RankedTensorType>();
3382     if (!input_type) return failure();
3383     // We can only match when the split dimension is a constant scalar.
3384     DenseIntElementsAttr split_dim_attr;
3385     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3386       return failure();
3387 
3388     // Get the dimension we are splitting at. Offset properly if it's negative.
3389     int64_t input_rank = input_type.getRank();
3390     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3391     if (dim_index < 0) dim_index += input_rank;
3392 
3393     // TODO(disc): remove this static shape check once the folding/canonicalization
3394     // function is added and ConvertSplitOp is deleted. This pattern only handles
3395     // splitting along a dynamic dimension; the static case is matched by
3396     // ConvertSplitOp above.
3397     int64_t c_input_dim_size = input_type.getDimSize(dim_index);
3398     if (!ShapedType::isDynamic(c_input_dim_size)) return failure();
3399 
3400     Value input_dim_size =
3401         rewriter.create<tensor::DimOp>(loc, input, dim_index);
3402     // Calculate the dimension size for each slice along the split dimension.
3403     int num_splits = op.getNumResults();
3404     Value num_splits_value = rewriter.create<arith::ConstantOp>(
3405         loc, rewriter.getIndexAttr(num_splits));
3406     Value slice_size =
3407         rewriter.create<arith::DivSIOp>(loc, input_dim_size, num_splits_value);
3408 
3409     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3410     Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
3411 
3412     SmallVector<Value, 4> begin_indices(input_rank, zero);
3413     SmallVector<Value, 4> end_indices;
3414     end_indices.reserve(input_rank);
3415     SmallVector<Value, 4> strides(input_rank, one);
3416     for (int i = 0; i < input_rank; ++i) {
3417       end_indices.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
3418     }
3419 
3420     // All mhlo.real_dynamic_slice results used to replace the original tf.Split op.
3421     SmallVector<Value, 4> slices;
3422     slices.reserve(num_splits);
3423 
3424     for (int i = 0; i < num_splits; ++i) {
3425       begin_indices[dim_index] = rewriter.create<arith::MulIOp>(
3426           loc, slice_size, rewriter.create<arith::ConstantIndexOp>(loc, i));
3427       end_indices[dim_index] = rewriter.create<arith::MulIOp>(
3428           loc, slice_size, rewriter.create<arith::ConstantIndexOp>(loc, i + 1));
3429 
3430       Type index_ty = rewriter.getIndexType();
3431       auto begin_value = rewriter.create<tensor::FromElementsOp>(
3432           loc,
3433           RankedTensorType::get({static_cast<int64_t>(begin_indices.size())},
3434                                 index_ty),
3435           begin_indices);
3436       auto end_value = rewriter.create<tensor::FromElementsOp>(
3437           loc,
3438           RankedTensorType::get({static_cast<int64_t>(end_indices.size())},
3439                                 index_ty),
3440           end_indices);
3441       auto stride_value = rewriter.create<tensor::FromElementsOp>(
3442           loc,
3443           RankedTensorType::get({static_cast<int64_t>(strides.size())},
3444                                 index_ty),
3445           strides);
3446       slices.push_back(rewriter.create<RealDynamicSliceOp>(
3447           loc, op.getOperation()->getResult(i).getType(), input, begin_value,
3448           end_value, stride_value));
3449     }
3450 
3451     rewriter.replaceOp(op, slices);
3452     return success();
3453   }
3454 };
3455 
3456 // Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
3457 // be split has fully static shape and the dimension to split and split sizes
3458 // are constants.
3459 //
3460 // This is similar to the conversion for the tf.Split op, except that the size
3461 // of each chunk along the dimension to split is explicitly given as an op
3462 // operand and the sizes are not necessarily the same.
3463 //
3464 // For example, given the following IR:
3465 //
3466 // %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
3467 // %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
3468 // %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
3469 //                   (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
3470 //                   (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
3471 //
3472 // We will generate the following slices:
3473 // %0 = "mhlo.slice"(%input) {
3474 //        limit_indices = dense<[4, 1]> : tensor<2xi64>,
3475 //        start_indices = dense<0> : tensor<2xi64>,
3476 //        strides = dense<1> : tensor<2xi64>} :
3477 //        (tensor<4x6xf32>) -> tensor<4x1xf32>
3478 // %1 = "mhlo.slice"(%input) {
3479 //        limit_indices = dense<[4, 3]> : tensor<2xi64>,
3480 //        start_indices = dense<[0, 1]> : tensor<2xi64>,
3481 //        strides = dense<1> : tensor<2xi64>} :
3482 //        (tensor<4x6xf32>) -> tensor<4x2xf32>
3483 // %2 = "mhlo.slice"(%input) {
3484 //        limit_indices = dense<[4, 6]> : tensor<2xi64>,
3485 //        start_indices = dense<[0, 3]> : tensor<2xi64>,
3486 //        strides = dense<1> : tensor<2xi64>} :
3487 //        (tensor<4x6xf32>) -> tensor<4x3xf32>
3488 class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
3489  public:
3490   using OpRewritePattern::OpRewritePattern;
3491 
3492   LogicalResult matchAndRewrite(TF::SplitVOp op,
3493                                 PatternRewriter &rewriter) const override {
3494     // We can only split along static dimensions.
3495     // TODO(b/145731001): enhance to support dynamic-shaped inputs.
3496     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3497     if (!input_type) return failure();
3498 
3499     // We can only match when the split dimension is a constant scalar.
3500     DenseIntElementsAttr split_dim_attr;
3501     if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3502       return failure();
3503 
3504     // We can only match when the split sizes are a constant int vector.
3505     DenseIntElementsAttr split_sizes_attr;
3506     if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
3507       return failure();
3508 
3509     // Get each chunk's size along the dimension to split. It may contain
3510     // dynamic sizes and we need to update it if so.
3511     SmallVector<int64_t, 4> split_sizes;
3512     int64_t total_dim_size = 0;  // Total dimension size assigned to splits
3513     llvm::Optional<int> dynamic_dim_index;
3514     split_sizes.reserve(
3515         split_sizes_attr.getType().cast<ShapedType>().getNumElements());
3516     for (auto dim : llvm::enumerate(split_sizes_attr)) {
3517       int64_t dim_val = dim.value().getSExtValue();
3518       split_sizes.push_back(dim_val);
3519       if (dim_val == ShapedType::kDynamicSize) {
3520         // We cannot have more than one dynamic dimension.
3521         assert(!dynamic_dim_index && "invalid split sizes");
3522         dynamic_dim_index = dim.index();
3523       } else {
3524         total_dim_size += dim_val;
3525       }
3526     }
3527 
3528     // Get the dimension we are splitting at. Offset properly if it's negative.
3529     int64_t input_rank = input_type.getRank();
3530     int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3531     if (dim_index < 0) dim_index += input_rank;
3532 
3533     int64_t input_dim_size = input_type.getDimSize(dim_index);
3534     if (ShapedType::isDynamic(input_dim_size)) return failure();
3535 
3536     assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
3537             (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
3538            "invalid split sizes");
3539 
3540     // Update the dynamic dimension with calculated concrete size.
3541     if (dynamic_dim_index)
3542       split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
3543 
3544     // Parameters for constructing each slice.
3545     SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3546     auto end_indices = llvm::to_vector<4>(input_type.getShape());
3547     SmallVector<int64_t, 4> strides(input_rank, 1);
3548 
3549     // All HLO slice results used to replace the original tf.Split op.
3550     SmallVector<Value, 4> slices;
3551     slices.reserve(op.getNumResults());
3552 
3553     for (int i = 0, end = op.getNumResults(); i < end; ++i) {
3554       end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
3555       slices.push_back(rewriter.create<mhlo::SliceOp>(
3556           op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
3557           GetI64ElementsAttr(end_indices, &rewriter),
3558           GetI64ElementsAttr(strides, &rewriter)));
3559       // Prepare the begin index for the next slice.
3560       begin_indices[dim_index] = end_indices[dim_index];
3561     }
3562 
3563     rewriter.replaceOp(op, slices);
3564     return success();
3565   }
3566 };
3567 
3568 // Converts StridedSlice op to HLO Slice op along with Reverse op to handle
3569 // negative strides and Reshape op to update the output shape. Indices and
3570 // strides operands are converted to attributes with non-negative indexing.
3571 //
3572 // If the begin input is not a compile time constant, the begin input needs to
3573 // be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this
3574 // case, strides must have a known value of 1 (otherwise we have insufficient
3575 // information to conform to XLA's op semantics).
3576 //
3577 // For example, with an op like the following,
3578 //   tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
3579 //     : tensor<AxBxf32> -> tensor<Pxf32>
3580 //
3581 // If the %begin input is constant, output would be:
3582 //   %reversed = "mhlo.Reverse" (%input) {dimensions = ...}
3583 //   %sliced = "mhlo.Slice" (%input)
3584 //             {start_indices = ..., limit_indices = ..., strides = ...}
3585 //   %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
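//
// If the %begin input is not constant, an illustrative sketch of the output
// (the begin values are extracted per dimension and wrapped around when
// negative) would be:
//   %sliced = "mhlo.dynamic_slice"(%input, %begin_d0, %begin_d1)
//             {slice_sizes = ...}
//   %output = "mhlo.reshape"(%sliced) : tensor<1xPxf32> -> tensor<Pxf32>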
3586 //
3587 class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
3588  public:
3589   using OpRewritePattern::OpRewritePattern;
3590 
3591   LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
3592                                          ArrayRef<int64_t> begin_indices,
3593                                          ArrayRef<int64_t> end_indices,
3594                                          ArrayRef<int64_t> strides,
3595                                          RankedTensorType input_ty,
3596                                          PatternRewriter &rewriter) const {
3597     SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
3598         dims_to_reverse;
3599     int64_t input_rank = input_ty.getRank();
3600     ArrayRef<int64_t> input_shape = input_ty.getShape();
3601     hlo_begin_indices.reserve(input_rank);
3602     hlo_end_indices.reserve(input_rank);
3603     hlo_strides.reserve(input_rank);
3604 
3605     int64_t indices_elements = begin_indices.size();
3606     if (input_rank < indices_elements) return failure();
3607 
3608     // Convert from TensorFlow negative or out of range indices and strides
3609     // values to legal HLO Slice attributes.
3610     for (int i = 0, e = indices_elements; i != e; i++) {
3611       int64_t begin = begin_indices[i];
3612       int64_t end = end_indices[i];
3613       int64_t stride = strides[i];
3614 
3615       if (stride < 0) {
3616         // Negative stride means that the output values are computed starting
3617         // from end until begin. Mark the dimension for reversal before slice
3618         // and compute indices for the reversed input.
3619         dims_to_reverse.push_back(i);
3620         begin = (input_shape[i] - 1) - begin;
3621         end = (input_shape[i] - 1) - end;
3622         stride = -stride;
3623       }
3624 
3625       // Unlike TensorFlow, HLO requires begin and end values to be within
3626       // range.
3627       begin = std::max(int64_t(0), begin);
3628       end = std::max(begin, end);
3629       end = std::min(end, input_shape[i]);
3630 
3631       hlo_begin_indices.push_back(begin);
3632       hlo_end_indices.push_back(end);
3633       hlo_strides.push_back(stride);
3634     }
3635 
3636     Location loc = op.getLoc();
3637     Value input = op.input();
3638     if (!dims_to_reverse.empty())
3639       input = rewriter.create<ReverseOp>(
3640           loc, input_ty, op.input(),
3641           GetI64ElementsAttr(dims_to_reverse, &rewriter));
3642     auto sliced = rewriter.create<SliceOp>(
3643         loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
3644         GetI64ElementsAttr(hlo_end_indices, &rewriter),
3645         GetI64ElementsAttr(hlo_strides, &rewriter));
3646 
3647     // Reshape slice result so that the shape is updated depending on
3648     // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3649     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3650     return success();
3651   }
3652 
3653   LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
3654                                         RankedTensorType input_ty,
3655                                         RankedTensorType result_ty,
3656                                         PatternRewriter &rewriter) const {
3657     // If begin and end values are dynamic, we can only support this lowering
3658     // if strides are a known value of 1.
3659     DenseIntElementsAttr sparse_strides_attr;
3660     if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
3661       return rewriter.notifyMatchFailure(
3662           op,
3663           "requires that strides are known when begin/end values are dynamic");
3664     }
3665     SmallVector<int64_t, 4> strides;
3666     int64_t stride_value;
3667     for (const APInt &stride : sparse_strides_attr) {
3668       if ((stride_value = stride.getSExtValue()) != 1) {
3669         return rewriter.notifyMatchFailure(op,
3670                                            "requires that strides are all 1 "
3671                                            "when begin/end values are dynamic");
3672       }
3673       strides.push_back(stride_value);
3674     }
3675 
3676     ArrayRef<int64_t> input_shape = input_ty.getShape();
3677     int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
3678 
3679     // When begin/end values are dynamic, we can only support shrinking a major
3680     // axis. For instance, if there are 4 dims, we can support a
3681     // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
3682     // other.
3683     bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask());
3684     if (!shrink_axis_mask_ok)
3685       return rewriter.notifyMatchFailure(
3686           op,
3687           "requires that shrink_axis_mask, if set, refer to a major axis "
3688           "dimension (when begin/end values are dynamic)");
3689 
3690     // When begin/end values are dynamic, the ellipsis mask, if set, must refer
3691     // to the last dimension.
3692     int ellipsis_mask = op.ellipsis_mask();
3693     if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
3694       return rewriter.notifyMatchFailure(
3695           op,
3696           "requires that ellipsis_mask, if set, refer to the last dimension of "
3697           "input (when begin/end values are dynamic)");
3698 
3699     uint64_t new_axis_mask = op.new_axis_mask();
3700     if (new_axis_mask)
3701       return rewriter.notifyMatchFailure(
3702           op,
3703           "requires that new_axis_mask is either set to 0 or not set when "
3704           "begin/end values are dynamic");
3705 
3706     // In this case, where the begin and end values are dynamic, we only support
3707     // cases where the number of output elements is equal to the number of input
3708     // elements that are sliced. Each dimension is either sliced fully or sliced
3709     // with a size of one.
3710     int output_elements = result_ty.getNumElements();
3711     int input_elements_sliced = 1;
3712 
3713     // Begin must be a ranked, 1-dimensional tensor: This is checked by the
3714     // verifier.
3715     int64_t slicing_dim_size =
3716         op.begin().getType().cast<RankedTensorType>().getDimSize(0);
3717     uint64_t begin_mask = op.begin_mask();
3718     uint64_t end_mask = op.end_mask();
3719     const int input_rank = input_shape.size();
3720     for (int d = 0; d < input_rank; ++d) {
3721       // Each dimension is either sliced fully or has size of one.
3722       if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) ||
3723           (d >= slicing_dim_size)) {
3724         input_elements_sliced *= input_shape[d];
3725       }
3726     }
3727     if (input_elements_sliced != output_elements) {
3728       return rewriter.notifyMatchFailure(
3729           op,
3730           "requires the number of output elements to be equal to the number of "
3731           "input elements sliced (when begin/end values are dynamic)");
3732     }
3733 
3734     SmallVector<Value, 4> slice_begin_indices;
3735     // For the dimensions that are to be sliced, all have slice sizes of 1.
3736     SmallVector<int64_t, 4> slice_sizes;
3737     auto begin_element_ty =
3738         op.begin().getType().cast<ShapedType>().getElementType();
3739     // Scalar tensor type.
3740     TensorType type = RankedTensorType::get(/*shape=*/{}, begin_element_ty);
3741     Location loc = op.getLoc();
3742     auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter);
3743     for (int d = 0; d < input_rank; ++d) {
3744       if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) ||
3745           (d >= slicing_dim_size)) {
3746         slice_begin_indices.push_back(zero);
3747         slice_sizes.push_back(input_shape[d]);
3748         continue;
3749       }
3750 
3751       auto index = rewriter.create<SliceOp>(
3752           loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
3753           GetI64ElementsAttr({d + 1}, &rewriter),
3754           GetI64ElementsAttr({1}, &rewriter));
3755       // Convert index to scalar.
3756       auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
3757       // If the index is negative, wrap it around with dimension size.
3758       auto index_negative =
3759           rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
3760       auto input_val = GetScalarConstOfType(begin_element_ty, loc,
3761                                             input_shape[d], &rewriter);
3762       auto wrapped_index =
3763           rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
3764       auto final_index = rewriter.create<SelectOp>(
3765           loc, type, index_negative, wrapped_index, reshaped_index);
3766       slice_begin_indices.push_back(final_index);
3767       slice_sizes.push_back(1);
3768     }
3769 
3770     auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
3771     auto sliced_type =
3772         RankedTensorType::get(slice_sizes, op.getType().getElementType());
3773     // This must be an xla DynamicSlice op due to the inputs that aren't
3774     // constant.
3775     auto sliced = rewriter.create<DynamicSliceOp>(
3776         loc, sliced_type, op.input(), slice_begin_indices, slice_sizes_attr);
3777 
3778     // Reshape slice result so that the shape is updated depending on
3779     // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3780     rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3781     return success();
3782   }
3783 
3784   LogicalResult matchAndRewrite(TF::StridedSliceOp op,
3785                                 PatternRewriter &rewriter) const override {
3786     // Input shape needs to be static to convert negative indices in TensorFlow
3787     // to absolute indices required by HLO.
3788     //
3789     // TODO(hinsu): Relax this constraint for ops without negative indices and
3790     // strides.
3791     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3792     if (!input_ty || !input_ty.hasStaticShape()) return failure();
3793 
3794     // Output shape needs to be static to apply 'new_axis_mask' or
3795     // 'shrink_axis_mask' by reshaping tensor after slice.
3796     //
3797     // TODO(hinsu): Relax this constraint for ops without the above masks.
3798     auto result_ty = op.getType().dyn_cast<RankedTensorType>();
3799     if (!result_ty || !result_ty.hasStaticShape()) return failure();
3800 
3801     DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
3802     if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
3803         !matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
3804       return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
3805     }
3806 
3807     SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3808     if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
3809       return failure();
3810     }
3811     return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
3812                                     input_ty, rewriter);
3813   }
3814 };
3815 
3816 // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
3817 //
3818 // tf.StridedSlice takes a slice of the input tensor. tf.StridedSliceGrad does
3819 // the reverse: it propagates the gradient for the sliced tensor back to the
3820 // original input tensor by padding with zeros. The main logic is calculating the
3821 // indices and strides for padding.
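//
// As an illustrative sketch (exact attribute spelling may differ), for
//   tf.StridedSliceGrad(shape=[6], begin=[1], end=[5], strides=[2], %dy)
// with %dy : tensor<2xf32>, the gradient is padded back to tensor<6xf32>:
//
//   %zero = mhlo.constant dense<0.000000e+00> : tensor<f32>
//   %result = "mhlo.pad"(%dy, %zero) {edge_padding_low = dense<1>,
//       edge_padding_high = dense<2>, interior_padding = dense<1>}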
3822 class ConvertStridedSliceGradOp
3823     : public OpRewritePattern<TF::StridedSliceGradOp> {
3824  public:
3825   using OpRewritePattern::OpRewritePattern;
3826 
3827   LogicalResult matchAndRewrite(TF::StridedSliceGradOp op,
3828                                 PatternRewriter &rewriter) const override {
3829     // We need constant input shape to perform padding calculations later.
3830     DenseIntElementsAttr input_shape_attr;
3831     if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
3832       return failure();
3833 
3834     // We also need constant begin/end indices and strides to perform padding
3835     // calculations.
3836     // Bounded shape after performing strided slice
3837     SmallVector<int64_t, 4> shape;
3838     // Bounded begin, end, and strides for strided slice
3839     SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3840     if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
3841                                          &strides))
3842       return failure();
3843 
3844     Value grad = op.dy();
3845     Type element_type = grad.getType().cast<ShapedType>().getElementType();
3846 
3847     // Perform reshape to undo any new/shrink axes done by strided slice.
3848     grad = rewriter.create<mhlo::ReshapeOp>(
3849         op.getLoc(), RankedTensorType::get(shape, element_type), grad);
3850 
3851     SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
3852     SmallVector<int64_t, 4> dims_to_reverse;
3853     padding_low.reserve(shape.size());
3854     padding_high.reserve(shape.size());
3855     padding_interm.reserve(shape.size());
3856 
3857     // Prepare padding parameters for each dimension.
3858     for (int i = 0, e = shape.size(); i < e; ++i) {
3859       int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
3860       if (strides[i] > 0) {
3861         padding_low.push_back(begin_indices[i]);
3862         padding_interm.push_back(strides[i] - 1);
3863 
3864         // Pad the upper dimension up to the expected input shape. It's not
3865         // sufficient simply to use end_indices[i] to compute the padding in
3866         // cases where the stride does not divide evenly into the interval
3867         // between begin_indices[i] and end_indices[i].
3868         int64_t size =
3869             padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3870         padding_high.push_back(input_dim - size);
3871       } else {
3872         dims_to_reverse.push_back(i);
3873         padding_high.push_back(input_dim - begin_indices[i] - 1);
3874         padding_interm.push_back(-strides[i] - 1);
3875 
3876         // Pad the lower dimension up to the expected input shape.
3877         int64_t size =
3878             padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3879         padding_low.push_back(input_dim - size);
3880       }
3881     }
3882 
3883     if (!dims_to_reverse.empty()) {
3884       grad = rewriter.create<mhlo::ReverseOp>(
3885           op.getLoc(), grad.getType(), grad,
3886           GetI64ElementsAttr(dims_to_reverse, &rewriter));
3887     }
3888 
3889     auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
3890     rewriter.replaceOpWithNewOp<mhlo::PadOp>(
3891         op, op.getType(), grad, zero,
3892         GetI64ElementsAttr(padding_low, &rewriter),
3893         GetI64ElementsAttr(padding_high, &rewriter),
3894         GetI64ElementsAttr(padding_interm, &rewriter));
3895     return success();
3896   }
3897 };
3898 
3899 /// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and
3900 /// offset applied to generate the range values. The output tensor needs to
3901 /// have a static shape.
3902 ///
3903 /// For example an op like the following:
3904 ///   %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
3905 ///      : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
3906 ///
3907 /// Output would be:
3908 ///   %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
3909 ///   %scaled = "mhlo.multiply"(%iota, %delta)
3910 ///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3911 ///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3912 ///   %result = "mhlo.add"(%scaled, %offset)
3913 ///       {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3914 ///       (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3915 ///
3916 /// Implementation is defined in C++ because there is no type inference for the iota op.
3917 class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
3918   using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3919 
3920   LogicalResult matchAndRewrite(TF::RangeOp op,
3921                                 PatternRewriter &rewriter) const override {
3922     auto result = op.getResult();
3923     auto result_type = result.getType();
3924     if (!result_type.cast<ShapedType>().hasStaticShape()) {
3925       return failure();
3926     }
3927 
3928     auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
3929                                         rewriter.getI64IntegerAttr(0));
3930     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3931         op.getLoc(), result_type, iota, op.delta(),
3932         hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
3933     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3934         op, result_type, scaled, op.start(),
3935         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
3936     return success();
3937   }
3938 };
3939 
3940 // Converts RangeOp for cases where the length is a dynamic value. The shape of
3941 // the resulting tensor is computed, then the start and delta are used with the
3942 // dynamic_iota value to compute the final range values.
3943 //
3944 // For example, the resulting range op value:
3945 //   %range = "tf.range"(%start, %limit, %delta)
3946 //
3947 // Is converted to the following.
3948 //   %start + %delta * iota(ceil(abs((%limit - %start) / %delta)))
3949 //
3950 // Implementation is defined in C++ due to the complicated type behavior.
3951 class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
3952   using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3953 
3954   LogicalResult matchAndRewrite(TF::RangeOp op,
3955                                 PatternRewriter &rewriter) const override {
3956     auto result = op.getResult();
3957     auto result_type = result.getType().cast<ShapedType>();
3958     if (result_type.hasStaticShape()) {
3959       return failure();
3960     }
3961 
3962     Value start = op.start();
3963     Value delta = op.delta();
3964     Value limit = op.limit();
3965 
3966     // To compute the length we need to use floating point calculations so that
3967     // ceil can be computed for the number of steps.
3968     auto compute_element_type =
3969         getElementTypeOrSelf(start.getType()).isa<FloatType>()
3970             ? getElementTypeOrSelf(start.getType())
3971             : rewriter.getF64Type();
3972     auto compute_type = RankedTensorType::get(
3973         limit.getType().cast<ShapedType>().getShape(), compute_element_type);
3974 
3975     // Compute the length of the sequence we are going to need. This includes
3976     // some conversion to float for the operations.
3977     //
3978     // %size = ceil(abs((%limit - %start) / %delta))
3979     auto range = rewriter.create<mhlo::SubtractOp>(op.getLoc(), limit, start);
3980     auto abs = rewriter.create<mhlo::AbsOp>(op.getLoc(), range);
3981 
3982     // Delta is not necessarily the same type as start and limit.
3983     auto abs_cast =
3984         rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, abs);
3985     auto delta_cast =
3986         rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, delta);
3987 
3988     // Compute the total number of integer steps and convert to the HLO
3989     // dimension tensor.
3990     auto normalized =
3991         rewriter.create<mhlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
3992     auto ceil = rewriter.create<mhlo::CeilOp>(op.getLoc(), normalized);
3993     auto steps = rewriter.create<mhlo::ConvertOp>(
3994         op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
3995     auto reshape = rewriter.create<mhlo::ReshapeOp>(
3996         op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
3997 
3998     // Using the resulting length compute the correct range value:
3999     //
4000     // %range = %start + %delta * iota(%size)
4001     auto out_scalar_type =
4002         RankedTensorType::get({}, getElementTypeOrSelf(result_type));
4003     auto start_out_cast =
4004         rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, start);
4005     auto delta_out_cast =
4006         rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, delta);
4007 
4008     auto iota = rewriter.create<DynamicIotaOp>(
4009         op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
4010     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4011         op.getLoc(), result_type, iota, delta_out_cast,
4012         hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
4013     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4014         op, result_type, scaled, start_out_cast,
4015         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
4016     return success();
4017   }
4018 };
4019 
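// Normalizes the axis values in `attr` to non-negative values with respect to
// the rank of `val` (i.e. maps each axis a to (a + rank) % rank) and returns
// the result as an i64 tensor attribute.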
4020 ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
4021   auto int_attr = attr.cast<DenseIntElementsAttr>();
4022   auto type = val.getType().cast<ShapedType>();
4023 
4024   SmallVector<int64_t, 6> axis;
4025   axis.reserve(int_attr.getNumElements());
4026 
4027   int64_t rank = type.getRank();
4028   for (auto val : int_attr.getValues<APInt>()) {
4029     axis.push_back((val.getSExtValue() + rank) % rank);
4030   }
4031 
4032   return builder->getI64TensorAttr(axis);
4033 }
4034 
4035 /// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling
4036 /// and offset applied to generate the linspace values. The output tensor needs
4037 /// to have a static shape.  The implementation is defined in C++ because there
4038 /// is no type inference for the iota op.
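/// As an illustrative sketch of the lowering, tf.LinSpace(0.0, 10.0, 5) would
/// become roughly:
///
///   %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
///   %step = (10.0 - 0.0) / (5 - 1)     // i.e. 2.5, computed with chlo ops
///   %scaled = %iota * %step
///   %result = %scaled + 0.0            // [0.0, 2.5, 5.0, 7.5, 10.0]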
4039 class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
4040   using OpRewritePattern<TF::LinSpaceOp>::OpRewritePattern;
4041 
4042   LogicalResult matchAndRewrite(TF::LinSpaceOp op,
4043                                 PatternRewriter &rewriter) const override {
4044     auto result = op.getResult();
4045     auto result_type = result.getType().dyn_cast<ShapedType>();
4046     if (!result_type || !result_type.hasStaticShape()) {
4047       return failure();
4048     }
4049 
4050     DenseIntElementsAttr num_attr;
4051     if (!matchPattern(op.num(), m_Constant(&num_attr))) {
4052       return rewriter.notifyMatchFailure(op, "Num must be a constant scalar");
4053     }
4054 
4055     if (num_attr.begin() == num_attr.end()) {
4056       return rewriter.notifyMatchFailure(op, "Num must not be empty");
4057     }
4058     int64_t num = (*num_attr.begin()).getSExtValue();
4059 
4060     // Calculate the scaling that needs to be applied to the iota.
4061     auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
4062         op.getLoc(), op.start().getType(), op.stop(), op.start(),
4063         hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
4064     Value step_denominator = rewriter.create<ConvertOp>(
4065         op.getLoc(), op.num(), result_type.getElementType());
4066     if (num > 1) {
4067       Value one = GetScalarConstOfType(result_type.getElementType(),
4068                                        op.getLoc(), 1, &rewriter);
4069       step_denominator = rewriter.create<chlo::BroadcastSubOp>(
4070           op.getLoc(), step_denominator.getType(), step_denominator, one,
4071           hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
4072     }
4073     auto step = rewriter.create<chlo::BroadcastDivOp>(
4074         op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
4075         hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator,
4076                                         step_denominator));
4077 
4078     // Scale the iota and add the offset.
4079     auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
4080                                         rewriter.getI64IntegerAttr(0));
4081     auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4082         op.getLoc(), result_type, iota, step,
4083         hlo::getBroadcastDimensionsAttr(&rewriter, iota, step));
4084     rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4085         op, result_type, scaled, op.start(),
4086         hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
4087     return success();
4088   }
4089 };
4090 
4091 /// Converts a generic OpTy tensorflow op to a mhlo.reduce op over
4092 /// ReductionOp.
4093 /// `is_accumulation` controls whether it uses higher precision for the actual
4094 /// reduction. This is set to false for ops like max where there is no precision
4095 /// concerns.
4096 //
4097 // The Derived class should have a static method to return the initial value to
4098 // use for reduction:
4099 //   static Value GetInitialValue(Type reduce_element_type, Location loc,
4100 //                                PatternRewriter *rewriter);
4101 // The reduce_element_type is guaranteed to be a float, int, or complex type
4102 // suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType.
4103 template <typename Derived, typename OpTy, typename ReductionOp,
4104           bool is_accumulation = true>
4105 class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
4106   using OpRewritePattern<OpTy>::OpRewritePattern;
4107 
4108   LogicalResult matchAndRewrite(OpTy op,
4109                                 PatternRewriter &rewriter) const override {
4110     // TODO(b/141785544): Update this to not require ranked shapes.
4111     // Input shape needs to be ranked to convert negative indices in TensorFlow
4112     // to absolute indices required by HLO.
4113     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
4114     if (!input_ty) return failure();
4115     ArrayRef<int64_t> input_shape = input_ty.getShape();
4116 
4117     DenseIntElementsAttr dimensions;
4118     if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
4119       return failure();
4120 
4121     // Build the final shape from input_shape and dimensions using a bitmap
4122     // to mark the reduced dimensions.
4123     SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
4124     SmallVector<int64_t, 4> xla_dimensions;
4125     for (const APInt &index_raw : dimensions.getValues<APInt>()) {
4126       int64_t index = index_raw.getSExtValue();
4127       int64_t rank = input_shape.size();
4128       if ((index < -rank || index >= rank)) return failure();
4129       index = (index + rank) % rank;
4130       reduced_dimensions_bitmap[index] = true;
4131       xla_dimensions.push_back(index);
4132     }
4133 
4134     Location loc = op.getLoc();
4135     Type element_type = input_ty.getElementType();
4136 
4137     // Only float, int, and complex types are currently supported.
4138     if (!element_type.isa<FloatType>() && !element_type.isa<IntegerType>() &&
4139         !element_type.isa<ComplexType>()) {
4140       return rewriter.notifyMatchFailure(
4141           op, "element type must be float, int, or complex type");
4142     }
4143 
4144     // Convert to an accumulation type to not lose precision when doing
4145     // repeated arithmetic operations.
4146     Type reduce_element_type =
4147         is_accumulation ? GetAccumulationType(element_type) : element_type;
4148     auto casted_input =
4149         rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
4150 
4151     // Each reduction op can have a different initial value.
4152     Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);
4153 
4154     auto reduction = rewriter.create<ReduceOp>(
4155         loc, casted_input.getResult(), init,
4156         GetI64ElementsAttr(xla_dimensions, &rewriter));
4157     BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
4158                                  &rewriter);
4159     Value result = reduction.getResult(0);
4160 
4161     // The mean op needs to divide by the product of the reduced dimensions.
4162     if (std::is_same<OpTy, TF::MeanOp>::value) {
4163       Value in_shape = rewriter.create<shape::ShapeOfOp>(loc, op.input());
4164       Value divisor_count = rewriter.create<arith::ConstantIndexOp>(loc, 1);
4165       for (size_t i = 0; i < input_shape.size(); ++i) {
4166         if (reduced_dimensions_bitmap[i]) {
4167           Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
4168           auto dim = rewriter.create<tensor::ExtractOp>(loc, in_shape, index);
4169           divisor_count =
4170               rewriter.create<arith::MulIOp>(loc, divisor_count, dim);
4171         }
4172       }
4173       // HLO ops are only defined on tensors, so we cast the divisor from
4174       // index -> i64 -> tensor<i64> -> tensor<reduction type>
4175       Value divisor_casted = rewriter.create<arith::IndexCastOp>(
4176           loc, rewriter.getI64Type(), divisor_count);
4177       Value divisor_tensor = rewriter.create<tensor::FromElementsOp>(
4178           loc, RankedTensorType::get({}, rewriter.getI64Type()),
4179           divisor_casted);
4180       Value divisor = rewriter.create<ConvertOp>(
4181           loc, RankedTensorType::get({}, reduce_element_type), divisor_tensor);
4182       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
4183       result = rewriter.create<chlo::BroadcastDivOp>(loc, result, divisor,
4184                                                      broadcast_dims);
4185     }
4186 
4187     result = rewriter.create<ConvertOp>(loc, result, element_type);
4188 
4189     // Need to reshape back after the reduction if we're keeping the reduced
4190     // dimensions. Note that we do this through successive (nominally 1)
4191     // applications of the TF ExpandDims op vs a more labor intensive
4192     // reshape. Various code generation techniques benefit from the knowledge
4193     // that this is a restricted form of shape manipulation that is just adding
4194     // unit dims.
4195     if (op.keep_dims()) {
4196       for (auto &dim_is_reduced : llvm::enumerate(reduced_dimensions_bitmap)) {
4197         if (dim_is_reduced.value()) {
4198           auto index_attr = GetI32ElementsAttr(
4199               {static_cast<int>(dim_is_reduced.index())}, &rewriter);
4200           Value index = rewriter.create<arith::ConstantOp>(loc, index_attr);
4201           result = rewriter.create<TF::ExpandDimsOp>(loc, result, index);
4202         }
4203       }
4204     }
4205     rewriter.replaceOp(op, {result});
4206 
4207     return success();
4208   }
4209 };
4210 
4211 // Converts Mean op to HLO Reduce op.
4212 //
4213 //   %init = arith.constant dense<...> : tensor<T>
4214 //   %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4215 //               {dimensions = ...}
4216 //   %divisor = arith.constant dense<...> : tensor<T>
4217 //   %mean = "mhlo.divide"(%sum, %divisor)
4218 class ConvertMeanOp
4219     : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
4220  public:
4221   using GenericConvertReductionOp::GenericConvertReductionOp;
4222   static Value GetInitialValue(Type reduce_element_type, Location loc,
4223                                PatternRewriter *rewriter) {
4224     return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter);
4225   }
4226 };
4227 
4228 // Converts Sum op to HLO Reduce op.
4229 //
4230 //   %init = arith.constant dense<...> : tensor<T>
4231 //   %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4232 //               {dimensions = ...}
4233 class ConvertSumOp
4234     : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
4235  public:
4236   using GenericConvertReductionOp::GenericConvertReductionOp;
4237 
4238   static Value GetInitialValue(Type reduce_element_type, Location loc,
4239                                PatternRewriter *rewriter) {
4240     // The neutral element of fp addition is -0.0, not 0.0: '0.0 + -0.0 = 0.0'.
4241     return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter);
4242   }
4243 };
4244 
4245 // Converts Max op to HLO Reduce op.
4246 //
4247 //   %init = arith.constant dense<...> : tensor<T>
4248 //   %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
4249 //               {dimensions = ...}
4250 class ConvertMaxOp
4251     : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
4252                                        /* is_accumulation= */ false> {
4253  public:
4254   using GenericConvertReductionOp::GenericConvertReductionOp;
4255 
4256   static Value GetInitialValue(Type reduce_element_type, Location loc,
4257                                PatternRewriter *rewriter) {
4258     return GetScalarLimitConstOfType(reduce_element_type, loc,
4259                                      hlo::kInfinityLowest, rewriter);
4260   }
4261 };
4262 
4263 // Converts Min op to HLO Reduce op.
4264 //
4265 //   %init = arith.constant dense<...> : tensor<T>
4266 //   %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"]
4267 //               {dimensions = ...}
4268 class ConvertMinOp
4269     : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
4270                                        /* is_accumulation= */ false> {
4271  public:
4272   using GenericConvertReductionOp::GenericConvertReductionOp;
4273 
4274   static Value GetInitialValue(Type reduce_element_type, Location loc,
4275                                PatternRewriter *rewriter) {
4276     return GetScalarLimitConstOfType(reduce_element_type, loc,
4277                                      hlo::kInfinityMax, rewriter);
4278   }
4279 };
4280 
4281 // Converts Prod op to HLO Reduce op.
4282 //
4283 //   %init = arith.constant dense<...> : tensor<T>
4284 //   %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"]
4285 //               {dimensions = ...}
4286 class ConvertProdOp
4287     : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
4288  public:
4289   using GenericConvertReductionOp::GenericConvertReductionOp;
4290 
4291   static Value GetInitialValue(Type reduce_element_type, Location loc,
4292                                PatternRewriter *rewriter) {
4293     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4294   }
4295 };
4296 
4297 // Converts All op to HLO Reduce op.
4298 //
4299 //   %init = arith.constant dense<...> : tensor<T>
4300 //   %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"]
4301 //               {dimensions = ...}
4302 class ConvertAllOp
4303     : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
4304  public:
4305   using GenericConvertReductionOp::GenericConvertReductionOp;
4306   static Value GetInitialValue(Type reduce_element_type, Location loc,
4307                                PatternRewriter *rewriter) {
4308     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4309   }
4310 };
4311 
4312 // Converts Any op to HLO Reduce op.
4313 //
4314 //   %init = arith.constant dense<...> : tensor<T>
4315 //   %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"]
4316 //               {dimensions = ...}
4317 class ConvertAnyOp
4318     : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
4319  public:
4320   using GenericConvertReductionOp::GenericConvertReductionOp;
4321   static Value GetInitialValue(Type reduce_element_type, Location loc,
4322                                PatternRewriter *rewriter) {
4323     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4324   }
4325 };
4326 
4327 // Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform
4328 // a reduction on the original input and the corresponding index. The reduction
4329 // sub-computation selects the max (or min) value and the index for the value.
4330 //   Derived: is the resulting derived class of this class.
4331 //   OpTy: is TF::ArgMaxOp or TF::ArgMinOp.
4332 template <typename Derived, typename OpTy>
4333 class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
4334   using OpRewritePattern<OpTy>::OpRewritePattern;
4335 
4336   LogicalResult matchAndRewrite(OpTy op,
4337                                 PatternRewriter &rewriter) const override {
4338     RankedTensorType input_type =
4339         op.input().getType().template dyn_cast<RankedTensorType>();
4340     if (!input_type) {
4341       return failure();
4342     }
4343 
4344     Type input_element_type = input_type.getElementType();
4345     // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
4346     // tf.ArgMax doesn't support complex data types, this check can be removed.
4347     if (!input_element_type.isSignlessIntOrFloat()) return failure();
4348 
4349     Location loc = op.getLoc();
4350     Value init_value =
4351         Derived::GetInitialValue(input_element_type, loc, rewriter);
4352 
4353     RankedTensorType output_type =
4354         op.output().getType().template dyn_cast<RankedTensorType>();
4355     if (!output_type) {
4356       return rewriter.notifyMatchFailure(op, "requires known rank");
4357     }
4358 
4359     Type index_element_type = output_type.getElementType();
4360     Value index_init_value =
4361         GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
4362 
4363     RankedTensorType index_type =
4364         RankedTensorType::get(input_type.getShape(), index_element_type);
4365 
4366     llvm::Optional<int64_t> optional_axis =
4367         GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
4368     if (!optional_axis.has_value())
4369       return rewriter.notifyMatchFailure(op, "required axis");
4370     int64_t axis = optional_axis.getValue();
4371 
4372     IntegerAttr iota_dimension =
4373         IntegerAttr::get(rewriter.getIntegerType(64), axis);
4374     Value input_shape = rewriter.create<shape::ShapeOfOp>(loc, op.input());
4375     Value index_values = rewriter.create<DynamicIotaOp>(
4376         loc, index_type, input_shape, iota_dimension);
4377 
4378     Value operands[] = {op.input(), index_values};
4379     Value init_values[] = {init_value, index_init_value};
4380     DenseIntElementsAttr reduction_dimensions =
4381         GetI64ElementsAttr({axis}, &rewriter);
4382 
4383     auto reduction = rewriter.create<ReduceOp>(
4384         loc, llvm::ArrayRef<Value>(operands),
4385         llvm::ArrayRef<Value>(init_values), reduction_dimensions);
4386     auto direction = Derived::GetDirection();
4387     BuildArgMinMaxReductionBody(input_element_type, index_element_type,
4388                                 direction, &reduction.body(), &rewriter);
4389 
4390     rewriter.replaceOp(op, {reduction.getResult(1)});
4391     return success();
4392   }
4393 };
4394 
4395 // Converts tensorflow ArgMax op to mhlo operations. The actual
4396 // implementation is in class ConvertArgMinMaxOp:
4397 //
4398 //   %init_index = arith.constant dense<...> : tensor<T>
4399 //   %init = arith.constant dense<...> : tensor<T>
4400 //   %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4401 //                              %init_index) ["mhlo.arg_max"]
4402 class ConvertArgMaxOp
4403     : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
4404  public:
4405   using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4406 
4407   static Value GetInitialValue(Type reduce_element_type, Location loc,
4408                                PatternRewriter &rewriter) {
4409     return GetScalarLimitConstOfType(reduce_element_type, loc,
4410                                      hlo::kInfinityLowest, &rewriter);
4411   }
4412 
4413   static ComparisonDirection GetDirection() { return ComparisonDirection::GE; }
4414 };
4415 
4416 // Converts tensorflow ArgMin op to mhlo operations. The actual
4417 // implementation is in class ConvertArgMinMaxOp:
4418 //
4419 //   %init_index = arith.constant dense<...> : tensor<T>
4420 //   %init = arith.constant dense<...> : tensor<T>
4421 //   %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4422 //                              %init_index) ["mhlo.arg_min"]
4423 class ConvertArgMinOp
4424     : public ConvertArgMinMaxOp<ConvertArgMinOp, TF::ArgMinOp> {
4425  public:
4426   using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4427 
4428   static Value GetInitialValue(Type reduce_element_type, Location loc,
4429                                PatternRewriter &rewriter) {
4430     return GetScalarLimitConstOfType(reduce_element_type, loc,
4431                                      hlo::kInfinityMax, &rewriter);
4432   }
4433 
4434   static ComparisonDirection GetDirection() { return ComparisonDirection::LE; }
4435 };
4436 
4437 // Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with
4438 // assignment:
4439 //
4440 //   %result = "mhlo.scatter"(%tensor, %indices, %updates)
4441 //     { dimensions = ... }
4442 //
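//   As an illustrative sketch (the shapes and dimension numbers below are made
//   up for exposition, not taken from a real test), a tf.TensorScatterUpdate of
//   the form
//
//     %r = "tf.TensorScatterUpdate"(%tensor, %indices, %updates)
//       : (tensor<4x3xf32>, tensor<2x1xi32>, tensor<2x3xf32>) -> tensor<4x3xf32>
//
//   would roughly lower to an mhlo.scatter with
//     update_window_dims = [1], inserted_window_dims = [0],
//     scatter_dims_to_operand_dims = [0], index_vector_dim = 1,
//   and an update computation region that simply returns the update value (or
//   combines it with the old value via add/sub/min/max for the other variants).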
4443 template <typename Derived, typename OpTy>
4444 class ConvertTensorScatterOp : public OpRewritePattern<OpTy> {
4445  public:
4446   using OpRewritePattern<OpTy>::OpRewritePattern;
4447 
4448   LogicalResult matchAndRewrite(OpTy op,
4449                                 PatternRewriter &rewriter) const override {
4450     auto tensor_ty =
4451         op.tensor().getType().template dyn_cast<RankedTensorType>();
4452     auto indices_ty =
4453         op.indices().getType().template dyn_cast<RankedTensorType>();
4454     auto updates_ty =
4455         op.updates().getType().template dyn_cast<RankedTensorType>();
4456 
4457     if (!tensor_ty || !indices_ty || !updates_ty) return failure();
4458     // Last dimension of the indices needs to be known at compile time for
4459     // computation of the 'update_window_dims' attribute in the dimensions
4460     // struct.
4461     int64_t num_index_dims = indices_ty.getShape().back();
4462     if (ShapedType::isDynamic(num_index_dims)) return failure();
4463 
4464     auto updates = op.updates();
4465 
4466     // Broadcast a scalar `updates` into the expected shape, which is:
4467     // updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
4468     if (updates_ty.getRank() == 0 &&
4469         (std::is_same<OpTy, TF::TensorScatterUpdateOp>::value ||
4470          std::is_same<OpTy, TF::TensorScatterAddOp>::value)) {
4471       if (!tensor_ty.hasStaticShape()) {
4472         return failure();
4473       }
4474 
4475       if (!indices_ty.hasStaticShape()) {
4476         return failure();
4477       }
4478 
4479       auto tensor_shape = tensor_ty.getShape();
4480       auto indices_shape = indices_ty.getShape();
4481       auto index_depth = indices_shape.back();
4482       llvm::SmallVector<int64_t> expected_update_shape;
4483 
4484       // Create the expected update shape to which the scalar update is broadcast.
4485       expected_update_shape.append(indices_shape.begin(),
4486                                    std::prev(indices_shape.end()));
4487 
4488       expected_update_shape.append(std::next(tensor_shape.begin(), index_depth),
4489                                    tensor_shape.end());
4490 
4491       auto const_type = RankedTensorType::get(
4492           {static_cast<int>(expected_update_shape.size())},
4493           rewriter.getIntegerType(64));
4494 
4495       auto const_attr = GetI64ElementsAttr(expected_update_shape, &rewriter);
4496 
4497       auto const_op =
4498           rewriter.create<TF::ConstOp>(op->getLoc(), const_type, const_attr);
4499 
4500       auto broadcast_to_type = RankedTensorType::get(
4501           llvm::makeArrayRef<int64_t>(expected_update_shape),
4502           updates_ty.getElementType());
4503 
4504       updates = rewriter.create<TF::BroadcastToOp>(
4505           op->getLoc(), broadcast_to_type, op.updates(), const_op);
4506 
4507       updates_ty = updates.getType().template dyn_cast<RankedTensorType>();
4508     }
4509 
4510     int64_t tensor_rank = tensor_ty.getRank();
4511     int64_t indices_rank = indices_ty.getRank();
4512     int64_t updates_rank =
4513         updates.getType().template dyn_cast<RankedTensorType>().getRank();
4514 
4515     int64_t window_dims = tensor_rank - num_index_dims;
4516     auto dims_attr = ScatterDimensionNumbersAttr::get(
4517         rewriter.getContext(),
4518         llvm::to_vector<4>(
4519             llvm::seq<int64_t>(updates_rank - window_dims, updates_rank)),
4520         llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
4521         llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
4522         indices_rank - 1);
4523 
4524     Location loc = op.getLoc();
4525     auto scatter = rewriter.create<ScatterOp>(loc, op.getType(),
4526                                               ValueRange(Value(op.tensor())),
4527                                               op.indices(), updates, dims_attr);
4528     Derived::BuildScatterBody(tensor_ty.getElementType(),
4529                               &scatter.update_computation(), loc, rewriter);
4530 
4531     rewriter.replaceOp(op, scatter.getResult(0));
4532     return success();
4533   }
4534 };
4535 
4536 class ConvertTensorScatterUpdateOp
4537     : public ConvertTensorScatterOp<ConvertTensorScatterUpdateOp,
4538                                     TF::TensorScatterUpdateOp> {
4539  public:
4540   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4541 
4542   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4543                                OpBuilder &builder) {
4544     OpBuilder::InsertionGuard guard(builder);
4545     Block *block = builder.createBlock(region);
4546     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4547     block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4548     builder.create<ReturnOp>(loc, block->getArgument(1));
4549   }
4550 };
4551 
4552 class ConvertTensorScatterAddOp
4553     : public ConvertTensorScatterOp<ConvertTensorScatterAddOp,
4554                                     TF::TensorScatterAddOp> {
4555  public:
4556   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4557 
4558   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4559                                OpBuilder &builder) {
4560     OpBuilder::InsertionGuard guard(builder);
4561     Block *block = builder.createBlock(region);
4562     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4563     block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4564     auto add_op = builder.create<AddOp>(loc, block->getArgument(0),
4565                                         block->getArgument(1));
4566     builder.create<ReturnOp>(loc, add_op.getResult());
4567   }
4568 };
4569 
4570 class ConvertTensorScatterSubOp
4571     : public ConvertTensorScatterOp<ConvertTensorScatterSubOp,
4572                                     TF::TensorScatterSubOp> {
4573  public:
4574   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4575 
4576   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4577                                OpBuilder &builder) {
4578     OpBuilder::InsertionGuard guard(builder);
4579     Block *block = builder.createBlock(region);
4580     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4581     block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4582     auto sub_op = builder.create<SubtractOp>(loc, block->getArgument(0),
4583                                              block->getArgument(1));
4584     builder.create<ReturnOp>(loc, sub_op.getResult());
4585   }
4586 };
4587 
4588 class ConvertTensorScatterMinOp
4589     : public ConvertTensorScatterOp<ConvertTensorScatterMinOp,
4590                                     TF::TensorScatterMinOp> {
4591  public:
4592   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4593 
4594   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4595                                OpBuilder &builder) {
4596     OpBuilder::InsertionGuard guard(builder);
4597     Block *block = builder.createBlock(region);
4598     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4599     block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4600     auto min_op = builder.create<MinOp>(loc, block->getArgument(0),
4601                                         block->getArgument(1));
4602     builder.create<ReturnOp>(loc, min_op.getResult());
4603   }
4604 };
4605 
4606 class ConvertTensorScatterMaxOp
4607     : public ConvertTensorScatterOp<ConvertTensorScatterMaxOp,
4608                                     TF::TensorScatterMaxOp> {
4609  public:
4610   using ConvertTensorScatterOp::ConvertTensorScatterOp;
4611 
4612   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4613                                OpBuilder &builder) {
4614     OpBuilder::InsertionGuard guard(builder);
4615     Block *block = builder.createBlock(region);
4616     Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4617     block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4618     auto max_op = builder.create<MaxOp>(loc, block->getArgument(0),
4619                                         block->getArgument(1));
4620     builder.create<ReturnOp>(loc, max_op.getResult());
4621   }
4622 };
4623 
4624 // Converts Tile op to HLO BroadcastInDim and Reshape ops.
4625 //   For shape [S1, S2] and multiples [M1, M2],
4626 //     MS1 = M1 * S1; MS2 = M2 * S2
4627 //
4628 //   %broadcast = mhlo.broadcast_in_dim(%input) {
4629 //     broadcast_dimensions = [1, 3]
4630 //   }
4631 //   %result = "mhlo.reshape"(%broadcast) : (tensor<M1xS1xM2xS2xf32>)
4632 //      -> tensor<MS1xMS2xf32>
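//
//   For example (a hypothetical instance of the pattern above): tiling a
//   tensor<4x8xf32> with multiples [7, 3] broadcasts the input into a
//   tensor<7x4x3x8xf32> (broadcast_dimensions = [1, 3]) and reshapes the
//   result to tensor<28x24xf32>.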
4633 class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
4634  public:
4635   using OpRewritePattern::OpRewritePattern;
4636 
4637   LogicalResult matchAndRewrite(TF::TileOp op,
4638                                 PatternRewriter &rewriter) const override {
4639     auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
4640     if (!input_ty || !input_ty.hasStaticShape()) return failure();
4641     ArrayRef<int64_t> input_shape = input_ty.getShape();
4642     Type element_type = input_ty.getElementType();
4643 
4644     DenseIntElementsAttr multiples;
4645     if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
4646         multiples.getType().getRank() != 1)
4647       return failure();
4648 
4649     const int64_t input_shape_size = input_shape.size();
4650     if (multiples.getNumElements() != input_shape_size) return failure();
4651 
4652     SmallVector<int64_t, 8> broadcasted_shape;
4653     SmallVector<int64_t, 4> broadcast_dimensions;
4654     broadcasted_shape.reserve(input_shape.size() * 2);
4655     broadcast_dimensions.reserve(input_shape.size());
4656     for (auto multiple_and_input :
4657          llvm::zip(multiples.getValues<APInt>(), input_shape)) {
4658       int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
4659       int64_t input_size = std::get<1>(multiple_and_input);
4660 
4661       if (multiple < 0) return failure();
4662 
4663       // Line the input up with the next dimension in broadcasted_shape
4664       // when broadcasting.
4665       int64_t broadcast_dim;
4666       int64_t output_size = input_size * multiple;
4667       if (input_size == 1 || multiple == 1) {
4668         // Special case for when normal broadcasting will just work.
4669         broadcast_dim = broadcasted_shape.size();
4670         broadcasted_shape.push_back(output_size);
4671       } else {
4672         // Tiling will happen for this dimension during the ReshapeOp below.
4673         broadcasted_shape.push_back(multiple);
4674         broadcast_dim = broadcasted_shape.size();
4675         broadcasted_shape.push_back(input_size);
4676       }
4677       broadcast_dimensions.push_back(broadcast_dim);
4678     }
4679     Location loc = op.getLoc();
4680     Type broadcasted_type =
4681         RankedTensorType::get(broadcasted_shape, element_type);
4682     Type output_type = op.getType();
4683 
4684     Value result = rewriter.create<BroadcastInDimOp>(
4685         loc, broadcasted_type, op.input(),
4686         GetI64ElementsAttr(broadcast_dimensions, &rewriter));
4687 
4688     if (output_type != broadcasted_type) {
4689       result = rewriter.create<ReshapeOp>(loc, output_type, result);
4690     }
4691 
4692     rewriter.replaceOp(op, {result});
4693 
4694     return success();
4695   }
4696 };
4697 
4698 // Converts the tf.TileOp into mhlo.dynamic_broadcast_in_dim and mhlo.dynamic_reshape.
4699 // TODO(disc): To recover static special case's performance with folding and
4700 // canonicalization.
4701 class ConvertTileOpDynamic : public OpRewritePattern<TF::TileOp> {
4702  public:
4703   using OpRewritePattern::OpRewritePattern;
4704   // clang-format off
4705   // Converts Tile op to HLO DBroadcastInDim and DReshape ops.
4706   //   For shape [S1, S2] and multiples [M1, M2],
4707   //     MS1 = M1 * S1; MS2 = M2 * S2
4708   //
4709   //   %out_dim_size = [M1, S1, M2, S2]
4710   //   %broadcast_dimensions = [1, 3];
4711   //   %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %broadcast_dimensions);
4712   //   %shape = [MS1, MS2]
4713   //   %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor<M1xS1xM2xS2xf32>) -> tensor<MS1xMS2xf32>
4714   // clang-format on
4715   LogicalResult matchAndRewrite(TF::TileOp op,
4716                                 PatternRewriter &rewriter) const final {
4717     Location loc = op.getLoc();
4718     Value input = op.input();
4719     Value multiples = op.multiples();
4720     auto input_ty = input.getType().dyn_cast<RankedTensorType>();
4721     if (!input_ty) return failure();
4722     // TODO(disc): Remove this constraint once fold and canonicalization
4723     // implemented.
4724     if (input_ty.hasStaticShape()) return failure();
4725 
4726     Type element_type = input_ty.getElementType();
4727     int64_t input_rank = input_ty.getRank();
4728     SmallVector<Value, 4> input_shape_values;
4729     for (int64_t i = 0; i < input_rank; ++i) {
4730       auto dim_size = input_ty.getDimSize(i);
4731       if (dim_size == ShapedType::kDynamicSize) {
4732         input_shape_values.push_back(
4733             rewriter.create<tensor::DimOp>(loc, input, i));
4734       } else {
4735         input_shape_values.push_back(rewriter.create<arith::ConstantOp>(
4736             loc, rewriter.getIndexAttr(dim_size)));
4737       }
4738     }
4739 
4740     auto multiples_ty = multiples.getType().dyn_cast<RankedTensorType>();
4741     int64_t multiples_rank = multiples_ty.getRank();
4742     // The rank of the multiples input of tf.TileOp must be 1.
4743     if (multiples_rank != 1) return failure();
4744     // The multiples input of tf.TileOp must be statically shaped and match input_rank.
4745     if ((!multiples_ty.hasStaticShape()) ||
4746         (multiples_ty.getDimSize(0) != input_rank)) {
4747       return failure();
4748     }
4749     Type index_ty = rewriter.getIndexType();
4750     // %out_dim_size
4751     SmallVector<Value, 4> out_dim_size;
4752     out_dim_size.reserve(input_rank * 2);
4753     for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4754       Value index = rewriter.create<arith::ConstantOp>(
4755           loc, rewriter.getIndexAttr(dim_idx));
4756       Value multiples_size =
4757           rewriter.create<tensor::ExtractOp>(loc, multiples, ValueRange{index});
4758       Value multiples_size_casted =
4759           rewriter.create<arith::IndexCastOp>(loc, index_ty, multiples_size);
4760       out_dim_size.push_back(multiples_size_casted);
4761       out_dim_size.push_back(input_shape_values[dim_idx]);
4762     }
4763     SmallVector<int64_t, 4> broadcast_dimensions;
4764     broadcast_dimensions.reserve(input_rank);
4765     for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4766       broadcast_dimensions.push_back(1 + 2 * dim_idx);
4767     }
4768     auto broadcast_dims_attr =
4769         GetI64ElementsAttr(broadcast_dimensions, &rewriter);
4770 
4771     Value out_dim_size_tensor = rewriter.create<tensor::FromElementsOp>(
4772         loc,
4773         RankedTensorType::get({static_cast<int64_t>(out_dim_size.size())},
4774                               index_ty),
4775         out_dim_size);
4776     SmallVector<int64_t, 4> broadcast_shape(input_rank * 2,
4777                                             ShapedType::kDynamicSize);
4778     RankedTensorType broadcast_type =
4779         RankedTensorType::get(broadcast_shape, element_type);
4780     Value broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
4781         loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr);
4782 
4783     // %shape = [MS1, MS2]
4784     SmallVector<Value, 4> shape_values;
4785     shape_values.reserve(input_rank);
4786     for (int64_t i = 0; i < input_rank; ++i) {
4787       Value dim_size_value = rewriter.create<mlir::arith::MulIOp>(
4788           loc, out_dim_size[2 * i], out_dim_size[2 * i + 1]);
4789       shape_values.push_back(dim_size_value);
4790     }
4791     Value shape = rewriter.create<tensor::FromElementsOp>(
4792         loc, RankedTensorType::get({input_rank}, index_ty), shape_values);
4793     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
4794                                                         broadcast, shape);
4795     return success();
4796   }
4797 };
4798 
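// Converts tf.MaxPoolGrad and tf.MaxPool3DGrad into an mhlo.select_and_scatter
// op whose `select` region keeps the window maximum (GE compare) and whose
// `scatter` region adds the routed gradients. A rough sketch (attributes and
// regions abbreviated, names illustrative):
//
//   %result = "mhlo.select_and_scatter"(%orig_input, %grad, %zero) ({
//     ... GE compare ...
//   }, {
//     ... add ...
//   }) {window_dimensions = ksize, window_strides = strides, padding = ...}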
4799 template <typename OpTy, int num_dims>
4800 class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
4801  public:
4802   using OpRewritePattern<OpTy>::OpRewritePattern;
4803 
4804   LogicalResult matchAndRewrite(OpTy op,
4805                                 PatternRewriter &rewriter) const override {
4806     Location loc = op.getLoc();
4807 
4808     Type element_type =
4809         op.orig_input().getType().template cast<TensorType>().getElementType();
4810 
4811     // Compute paddings using the original input and kernel shape and strides.
4812     // Here, the same padding as for ReduceWindow is used because the MaxPool
4813     // op is lowered to a ReduceWindow op.
4814     auto input_ty =
4815         op.orig_input().getType().template dyn_cast<RankedTensorType>();
4816     if (!input_ty) return failure();
4817     DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
4818         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
4819 
4820     auto result = rewriter.create<SelectAndScatterOp>(
4821         loc, op.getType(), op.orig_input(), op.grad(),
4822         GetScalarConstOfType(element_type, loc, 0, &rewriter),
4823         GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
4824         paddings_attr);
4825 
4826     BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
4827     {
4828       OpBuilder::InsertionGuard guard(rewriter);
4829       Block *block = rewriter.createBlock(&result.select());
4830 
4831       // Block arguments are scalars of the given element type.
4832       Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4833       block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4834 
4835       auto reducer = rewriter.create<CompareOp>(loc, block->getArgument(0),
4836                                                 block->getArgument(1),
4837                                                 ComparisonDirection::GE);
4838       rewriter.create<ReturnOp>(loc, reducer.getResult());
4839     }
4840 
4841     rewriter.replaceOp(op, {result});
4842 
4843     return success();
4844   }
4845 };
4846 
4847 using ConvertMaxPool2DGradOp =
4848     ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
4849 using ConvertMaxPool3DGradOp =
4850     ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
4851 
4852 // Converts tf.Conv?DBackpropInputOp into:
4853 //   %rev_filter = "mhlo.reverse"(%filter)
4854 //   %result = "mhlo.convolution"(%out_backprop, %rev_filter)
4855 template <typename OpTy, int num_spatial_dims>
4856 class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
4857  public:
4858   using OpRewritePattern<OpTy>::OpRewritePattern;
4859 
4860   LogicalResult matchAndRewrite(OpTy op,
4861                                 PatternRewriter &rewriter) const override {
4862     // Unpack all of the attributes.
4863     tensorflow::TensorFormat data_format;
4864     if (!FormatFromString(op.data_format().str(), &data_format))
4865       return op.emitOpError("invalid data format");
4866     constexpr int num_dims = num_spatial_dims + 2;
4867     int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
4868 
4869     tensorflow::Padding padding;
4870     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
4871       return failure();
4872 
4873     auto out_backprop_ty =
4874         op.out_backprop().getType().template dyn_cast<RankedTensorType>();
4875     auto filter_ty =
4876         op.filter().getType().template dyn_cast<RankedTensorType>();
4877 
4878     // With the exception of out_backprop's batch dimension, out_backprop and
4879     // filter need to have static shape. Filter is validated here, out_backprop
4880     // is mostly validated at use.
4881     if (!out_backprop_ty || !filter_ty || !filter_ty.hasStaticShape())
4882       return failure();
4883 
4884     // Compute input_shape by supporting either:
4885     //   1) Fully static shapes, represented as constants.
4886     //   2) Static shapes with a dynamic batch dimension, represented as
4887     //      1D tf.Pack of a batch dimension (can be static or dynamic)
4888     //      and other dimensions (can only be static), for example:
4889     //      "tf.Pack"(%142, %cst_301, %cst_301, %cst_300) {axis = 0 : i64, ...}
4890     std::vector<int64_t> input_shape;
4891     DenseIntElementsAttr input_shape_attr;
4892     if (matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) &&
4893         input_shape_attr.getType().getRank() == 1) {
4894       input_shape.insert(input_shape.end(),
4895                          input_shape_attr.getValues<int32_t>().begin(),
4896                          input_shape_attr.getValues<int32_t>().end());
4897     } else {
4898       auto pack = op.input_sizes().template getDefiningOp<TF::PackOp>();
4899       if (!pack || pack.axis() != 0) return failure();
4900       auto pack_ty = pack.getType().template dyn_cast<RankedTensorType>();
4901       if (!pack_ty || pack_ty.getRank() != 1) return failure();
4902       for (auto i = 0; i < pack_ty.getDimSize(0); ++i) {
4903         if (i == batch_dim) {
4904           // We don't use the batch dimension below, so we don't care about
4905           // its size. Might as well populate it with -1.
4906           input_shape.push_back(ShapedType::kDynamicSize);
4907         } else {
4908           DenseIntElementsAttr input_dims_attr;
4909           if (matchPattern(pack.values()[i], m_Constant(&input_dims_attr)) &&
4910               input_dims_attr.getType().getRank() == 0) {
4911             input_shape.push_back(input_dims_attr.getSplatValue<int32_t>());
4912           } else {
4913             return failure();
4914           }
4915         }
4916       }
4917     }
4918 
4919     auto dilations_attr = GetI64ElementsAttr(op.dilations());
4920     std::vector<int> dilations{
4921         dilations_attr.template getValues<int64_t>().begin(),
4922         dilations_attr.template getValues<int64_t>().end()};
4923     auto strides_attr = GetI64ElementsAttr(op.strides());
4924     std::vector<tensorflow::int32> strides{
4925         strides_attr.template getValues<int64_t>().begin(),
4926         strides_attr.template getValues<int64_t>().end()};
4927 
4928     std::vector<int64_t> explicit_paddings;
4929     if (padding == tensorflow::Padding::EXPLICIT) {
4930       // The EXPLICIT padding mode and the associated attribute are limited to
4931       // Conv2DBackpropInput. So, fetch attribute by identifier instead of the
4932       // op.explicit_paddings() attribute getter.
4933       ArrayRef<Attribute> explicit_paddings_attr =
4934           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
4935       explicit_paddings.reserve(explicit_paddings_attr.size());
4936       for (Attribute explicit_padding : explicit_paddings_attr)
4937         explicit_paddings.push_back(
4938             explicit_padding.cast<IntegerAttr>().getInt());
4939     }
4940 
4941     ArrayRef<int64_t> filter_shape = filter_ty.getShape();
4942 
4943     // Compute ConvDimensionNumbers, dilation, and padding.
4944     SmallVector<int64_t, num_spatial_dims> spatial_dims;
4945     SmallVector<int64_t, num_spatial_dims> lhs_dilation;
4946     SmallVector<int64_t, num_spatial_dims> rhs_dilation;
4947     SmallVector<int64_t, num_spatial_dims * 2> paddings;
4948 
4949     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
4950       const int64_t spatial_dim =
4951           GetTensorSpatialDimIndex(num_dims, data_format, i);
4952       spatial_dims.push_back(spatial_dim);
4953 
4954       // Prepare metadata indexed by spatial_dim for computing pad_before
4955       // and pad_after.
4956       int64_t input_size = input_shape[spatial_dim];
4957       if (input_size == ShapedType::kDynamicSize) return failure();
4958       int64_t output_size = out_backprop_ty.getDimSize(spatial_dim);
4959       if (output_size == ShapedType::kDynamicSize) return failure();
4960       int64_t filter_size = filter_ty.getDimSize(i);
4961       int64_t stride = strides[spatial_dim];
4962       int64_t dilation = dilations[spatial_dim];
4963 
4964       // Compute pad_before and pad_after following the logic from
4965       // ConvBackpropComputeDimensionsV2. (Unfortunately, we cannot call
4966       // the function in question because it doesn't work with dynamic dims).
4967       int64_t padding_before = -1, padding_after = -1;
4968       if (padding == tensorflow::Padding::EXPLICIT) {
4969         padding_before = explicit_paddings[2 * spatial_dim];
4970         padding_after = explicit_paddings[2 * spatial_dim + 1];
4971       }
4972       int64_t expected_output_size = 0;
4973       auto status = GetWindowedOutputSizeVerboseV2(
4974           input_size, filter_size, dilation, stride, padding,
4975           &expected_output_size, &padding_before, &padding_after);
4976       if (!status.ok()) return failure();
4977       if (output_size != expected_output_size) return failure();
4978       int64_t effective_filter_size = (filter_size - 1) * dilation + 1;
4979       int64_t pad_before = effective_filter_size - 1 - padding_before;
4980       int64_t padded_out_size = input_size + effective_filter_size - 1;
4981       int64_t expanded_output_size = (output_size - 1) * stride + 1;
4982       int64_t pad_after = padded_out_size - expanded_output_size - pad_before;
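      // As a worked example with hypothetical sizes: input_size = 7,
      // filter_size = 3, stride = 2, dilation = 1 and VALID padding give
      // output_size = 3, effective_filter_size = 3, pad_before = 2,
      // padded_out_size = 9, expanded_output_size = 5 and pad_after = 2.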
4983 
4984       // Populate metadata for the upcoming mhlo.conv op using the result of
4985       // the computations performed above.
4986       lhs_dilation.push_back(stride);
4987       rhs_dilation.push_back(dilation);
4988       paddings.push_back(pad_before);
4989       paddings.push_back(pad_after);
4990     }
4991 
4992     RankedTensorType paddings_ty = RankedTensorType::get(
4993         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
4994     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
4995 
4996     Value filter = op.filter();
4997 
4998     const int feature_dim =
4999         tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5000     const int64_t in_depth = *(input_shape.begin() + feature_dim);
5001     if (in_depth == ShapedType::kDynamicSize) return failure();
5002     const int64_t filter_in_depth = filter_shape[num_spatial_dims];
5003     const int64_t feature_group_count = in_depth / filter_in_depth;
5004 
5005     if (feature_group_count != 1) {
5006       // 1. Reshape filter from
5007       //   [H, W, ..., filter_in_depth, out_depth] to
5008       //   [H, W, ..., filter_in_depth, G, out_depth / G].
5009       auto new_shape = llvm::to_vector<6>(filter_shape);
5010       new_shape.back() = feature_group_count;
5011       new_shape.push_back(filter_shape.back() / feature_group_count);
5012       Type filter_element_ty = filter_ty.getElementType();
5013       auto ty = RankedTensorType::get(new_shape, filter_element_ty);
5014       filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5015 
5016       // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G].
5017       llvm::SmallVector<int64_t, 6> perm(num_dims + 1);
5018       std::iota(perm.begin(), perm.end(), 0);
5019       std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]);
5020       std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]);
5021       ty = RankedTensorType::get(new_shape, filter_element_ty);
5022       filter = rewriter.create<TransposeOp>(
5023           op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter));
5024 
5025       // 3. Reshape to [H, W, ..., in_depth, out_depth / G].
5026       new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1];
5027       new_shape[num_spatial_dims + 1] = new_shape.back();
5028       new_shape.pop_back();
5029       ty = RankedTensorType::get(new_shape, filter_element_ty);
5030       filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5031     }
5032 
5033     SmallVector<int64_t, 4> kernel_spatial_dims;
5034     kernel_spatial_dims.resize(num_spatial_dims);
5035     std::iota(kernel_spatial_dims.begin(), kernel_spatial_dims.end(), 0);
5036 
5037     // Mirror the filter in the spatial dimensions.
5038     filter = rewriter.create<ReverseOp>(
5039         op.getLoc(), filter,
5040         GetI64ElementsAttr(kernel_spatial_dims, &rewriter));
5041 
5042     // activation gradients
5043     //   = gradients (with padding and dilation) <conv> mirrored_weights
5044     Value result = rewriter.create<ConvolutionOp>(
5045         op.getLoc(), op.getType(), op.out_backprop(), filter,
5046         /*window_strides=*/
5047         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5048                                    &rewriter),
5049         /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
5050         GetI64ElementsAttr(rhs_dilation, &rewriter),
5051         /*window_reversal=*/nullptr,
5052         ConvDimensionNumbersAttr::get(
5053             rewriter.getContext(),
5054             /*input_batch_dimension=*/batch_dim,
5055             /*input_feature_dimension=*/feature_dim,
5056             /*input_spatial_dimensions=*/spatial_dims,
5057             // TF filter shape is [ H, W, ..., inC, outC ]
5058             // Transpose the input and output features for computing the
5059             // gradient.
5060             /*kernel_input_feature_dimension=*/
5061             num_spatial_dims + 1,
5062             /*kernel_output_feature_dimension=*/
5063             num_spatial_dims,
5064             /*kernel_spatial_dimensions=*/kernel_spatial_dims,
5065             /*output_batch_dimension=*/batch_dim,
5066             /*output_feature_dimension=*/feature_dim,
5067             /*output_spatial_dimensions=*/spatial_dims),
5068         rewriter.getI64IntegerAttr(feature_group_count),
5069         /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
5070         /*precision_config=*/ArrayAttr());
5071 
5072     rewriter.replaceOp(op, {result});
5073 
5074     return success();
5075   }
5076 };
5077 
5078 using ConvertConv2DBackpropInputOp =
5079     ConvertConvBackpropInputOp<TF::Conv2DBackpropInputOp,
5080                                /*num_spatial_dims=*/2>;
5081 using ConvertConv3DBackpropInputOp =
5082     ConvertConvBackpropInputOp<TF::Conv3DBackpropInputV2Op,
5083                                /*num_spatial_dims=*/3>;
5084 
5085 // Converts tf.Conv?DBackpropFilterOp into:
5086 //   %result = "mhlo.convolution"(%input, %out_backprop)
5087 template <typename OpTy, int num_spatial_dims>
5088 class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
5089  public:
5090   using OpRewritePattern<OpTy>::OpRewritePattern;
5091 
5092   LogicalResult matchAndRewrite(OpTy op,
5093                                 PatternRewriter &rewriter) const override {
5094     // Unpack all of the attributes.
5095     tensorflow::TensorFormat data_format;
5096     if (!FormatFromString(op.data_format().str(), &data_format))
5097       return op.emitOpError("invalid data format");
5098 
5099     tensorflow::Padding padding;
5100     if (!GetPaddingFromString(op.padding().str(), &padding).ok())
5101       return failure();
5102 
5103     auto out_backprop_ty =
5104         op.out_backprop().getType().template dyn_cast<RankedTensorType>();
5105     auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
5106 
5107     for (RankedTensorType ty : {out_backprop_ty, input_ty})
5108       if (!ty || !ty.hasStaticShape()) return failure();
5109 
5110     ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
5111     ArrayRef<int64_t> input_shape = input_ty.getShape();
5112 
5113     DenseIntElementsAttr filter_shape_attr;
5114     if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
5115         filter_shape_attr.getType().getRank() != 1)
5116       return failure();
5117 
5118     auto dilations_attr = GetI64ElementsAttr(op.dilations());
5119     std::vector<int> dilations{
5120         dilations_attr.template getValues<int64_t>().begin(),
5121         dilations_attr.template getValues<int64_t>().end()};
5122     auto strides_attr = GetI64ElementsAttr(op.strides());
5123     std::vector<tensorflow::int32> strides{
5124         strides_attr.template getValues<int64_t>().begin(),
5125         strides_attr.template getValues<int64_t>().end()};
5126 
5127     std::vector<int64_t> explicit_paddings;
5128     if (padding == tensorflow::Padding::EXPLICIT) {
5129       // The EXPLICIT padding mode and the associated attribute are limited to
5130       // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the
5131       // op.explicit_paddings() attribute getter.
5132       ArrayRef<Attribute> explicit_paddings_attr =
5133           op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
5134       explicit_paddings.reserve(explicit_paddings_attr.size());
5135       for (Attribute explicit_padding : explicit_paddings_attr)
5136         explicit_paddings.push_back(
5137             explicit_padding.cast<IntegerAttr>().getInt());
5138     }
5139 
5140     constexpr int num_dims = num_spatial_dims + 2;
5141     auto filter_shape = filter_shape_attr.getValues<int32_t>();
5142 
5143     // Reuse dimension computation logic from conv_grad_shape_utils.cc.
5144     tensorflow::ConvBackpropDimensions dims;
5145     if (!tensorflow::ConvBackpropComputeDimensionsV2(
5146              /*label=*/"", num_spatial_dims,
5147              ToTensorShape<int64_t, num_dims>(input_shape),
5148              ToTensorShape<int32_t, num_dims>(filter_shape),
5149              ToTensorShape<int64_t, num_dims>(out_backprop_shape), dilations,
5150              strides, padding, explicit_paddings, data_format, &dims)
5151              .ok()) {
5152       return failure();
5153     }
5154 
5155     // The activations (inputs) form the LHS of the convolution.
5156     // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
5157     // For the gradient computation, we need to:
5158     // 1. In the case of group convolution, move the num_groups dimension before
5159     // the batch dimension
5160     // 2. Swap the roles of the batch and feature dimensions.
5161     const int feature_dim =
5162         tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5163     const int64_t in_depth = input_shape[feature_dim];
5164     const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims);
5165     const int64_t batch_group_count = in_depth / filter_in_depth;
5166 
5167     // Compute ConvDimensionNumbers, dilation, and padding.
5168     SmallVector<int64_t, num_spatial_dims> spatial_dims;
5169     SmallVector<int64_t, num_spatial_dims> kernel_spatial_dims;
5170     SmallVector<int64_t, num_spatial_dims> rhs_dilation;
5171     SmallVector<int64_t, num_spatial_dims * 2> paddings;
5172     SmallVector<int64_t, num_spatial_dims> window_strides;
5173 
5174     // The filter gradients are computed by a convolution of the input
5175     // activations and the output gradients, with some appropriate padding.
5176     // See the comment at the top of conv_grad_ops.h for details.
5177 
5178     for (int i : llvm::seq<int>(0, num_spatial_dims)) {
5179       const int64_t dim =
5180           tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
5181       kernel_spatial_dims.push_back(dim);
5182       // Besides padding the input, we will also expand output_rows to
5183       //    expanded_out_rows = (output_rows - 1) * stride + 1
5184       // with zeros in between:
5185       //
5186       //      a . . . b . . . c . . . d . . . e
5187       //
5188       // This is done by specifying the window dilation factors in the
5189       // convolution HLO below.
5190       const auto &spatial_dim_i = dims.spatial_dims[i];
5191       rhs_dilation.push_back(spatial_dim_i.stride);
5192       window_strides.push_back(dilations[dim]);
5193 
5194       // We will also need to pad the input with zeros such that after the
5195       // convolution, we get the right size for the filter.
5196       // The padded_in_rows should be such that when we convolve this with the
5197       // expanded_out_rows as a filter, we should get filter_rows back.
5198 
5199       const int64_t padded_in_size =
5200           spatial_dim_i.expanded_output_size +
5201           (spatial_dim_i.filter_size - 1) * dilations[dim];
5202 
5203       // However it can be smaller than input_rows: in this
5204       // case it means some of the inputs are not used.
5205       //
5206       // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
5207       //
5208       // INPUT =  [ A  B  C ]
5209       //
5210       // FILTER = [ x y ]
5211       //
5212       // and the output will only have one column: a = A * x + B * y
5213       //
5214       // and input "C" is not used at all.
5215       //
5216       // We apply negative padding in this case.
5217       const int64_t pad_total = padded_in_size - spatial_dim_i.input_size;
5218 
5219       // + For the EXPLICIT padding, we pad the top/left side with the explicit
5220       //   padding and pad the bottom/right side with the remaining space.
5221       // + For the VALID padding, we don't pad anything on the top/left side
5222       //   and pad the bottom/right side with the remaining space.
5223       // + For the SAME padding, we pad top/left side the same as bottom/right
5224       //   side.
5225       //
5226       // In addition, if the padded input size is smaller than the input size,
5227       // we need to ignore some trailing elements of the input. We do this by
5228       // applying negative padding on the right/bottom.
5229       const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
5230                                      ? explicit_paddings[2 * dim]
5231                                  : padding == tensorflow::Padding::SAME
5232                                      ? std::max<int64_t>(pad_total / 2, 0)
5233                                      : 0;
5234       paddings.push_back(pad_before);
5235       paddings.push_back(pad_total - pad_before);
5236     }
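    // For the INPUT = [ A B C ] example above (filter_cols = 2, stride = 2,
    // dilation = 1, VALID padding): expanded_output_size = 1, padded_in_size = 2,
    // pad_total = -1, pad_before = 0 and pad_after = -1, so column C is dropped
    // via negative padding.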
5237 
5238     RankedTensorType paddings_ty = RankedTensorType::get(
5239         {num_spatial_dims, 2}, rewriter.getIntegerType(64));
5240     auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
5241 
5242     SmallVector<int64_t, 4> output_spatial_dimensions;
5243     output_spatial_dimensions.resize(num_spatial_dims);
5244     std::iota(output_spatial_dimensions.begin(),
5245               output_spatial_dimensions.end(), 0);
5246 
5247     const int batch_dim =
5248         tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
5249 
5250     Value result = rewriter.create<ConvolutionOp>(
5251         op.getLoc(), op.getType(), op.input(), op.out_backprop(),
5252         /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
5253         /*padding=*/paddings_attr, /*lhs_dilation=*/
5254         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5255                                    &rewriter),
5256         GetI64ElementsAttr(rhs_dilation, &rewriter),
5257         /*window_reversal=*/nullptr,
5258         ConvDimensionNumbersAttr::get(
5259             rewriter.getContext(),
5260             // Swap batch_dim and feature_dim in the activations.
5261             /*input_batch_dimension=*/feature_dim,
5262             /*input_feature_dimension=*/batch_dim,
5263             /*input_spatial_dimensions=*/kernel_spatial_dims,
5264             // The gradients become the RHS of the convolution.
5265             // The gradients have shape [batch, out_rows, out_cols, ...,
5266             // out_depth] where the batch becomes the input feature for the
5267             // convolution.
5268             /*kernel_input_feature_dimension=*/batch_dim,
5269             /*kernel_output_feature_dimension=*/feature_dim,
5270             /*kernel_spatial_dimensions=*/kernel_spatial_dims,
5271             /*output_batch_dimension=*/num_spatial_dims,
5272             /*output_feature_dimension=*/num_spatial_dims + 1,
5273             /*output_spatial_dimensions=*/output_spatial_dimensions),
5274         /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
5275         rewriter.getI64IntegerAttr(batch_group_count),
5276         /*precision_config=*/ArrayAttr());
5277 
5278     rewriter.replaceOp(op, {result});
5279 
5280     return success();
5281   }
5282 };
5283 
5284 using ConvertConv2DBackpropFilterOp =
5285     ConvertConvBackpropFilterOp<TF::Conv2DBackpropFilterOp,
5286                                 /*num_spatial_dims=*/2>;
5287 using ConvertConv3DBackpropFilterOp =
5288     ConvertConvBackpropFilterOp<TF::Conv3DBackpropFilterV2Op,
5289                                 /*num_spatial_dims=*/3>;
5290 
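// Converts tf.OneHot into a sequence of mhlo ops: an iota along the one-hot
// axis, a broadcast of the indices to the output shape, an equality compare,
// and a select between the broadcast on/off values. Roughly (a sketch, with
// attributes elided):
//
//   %iota = "mhlo.iota"() {iota_dimension = axis}
//   %idx = "mhlo.broadcast_in_dim"(%indices) {broadcast_dimensions = ...}
//   %cmp = "mhlo.compare"(%idx, %iota) {comparison_direction = EQ}
//   %result = "mhlo.select"(%cmp, %on_broadcast, %off_broadcast)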
5291 class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
5292  public:
5293   using OpRewritePattern::OpRewritePattern;
5294 
5295   LogicalResult matchAndRewrite(TF::OneHotOp op,
5296                                 PatternRewriter &rewriter) const override {
5297     auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
5298     if (!indices_ty || !indices_ty.hasStaticShape()) return failure();
5299     ArrayRef<int64_t> indices_shape = indices_ty.getShape();
5300     Type element_type = indices_ty.getElementType();
5301 
5302     DenseIntElementsAttr depth_attr;
5303     if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
5304       return failure();
5305     }
5306 
5307     int64_t depth = depth_attr.getValues<APInt>()[0].getSExtValue();
5308     int64_t axis = op.axis();
5309     if (axis == -1) axis = indices_shape.size();
5310 
5311     llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
5312     std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
5313     std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
5314 
5315     llvm::SmallVector<int64_t, 4> output_dims =
5316         llvm::to_vector<4>(indices_shape);
5317     output_dims.insert(output_dims.begin() + axis, depth);
5318 
5319     Location loc = op.getLoc();
5320 
5321     // The iota result has the effective output shape of the computation,
5322     // and indices must be broadcast into it. At this point, this computation
5323     // would need to be reworked quite a bit to support dynamic shapes, so
5324     // just using static broadcasting.
5325     auto index_type = RankedTensorType::get(output_dims, element_type);
5326     auto iota = rewriter.create<IotaOp>(
5327         loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
5328     auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
5329         loc, index_type, op.indices(),
5330         GetI64ElementsAttr(broadcast_dims, &rewriter));
5331 
5332     Value compare = rewriter.create<mhlo::CompareOp>(
5333         loc, broadcast_indices, iota, ComparisonDirection::EQ);
5334     Value on_value = rewriter.create<BroadcastOp>(
5335         loc, op.getType(), op.on_value(),
5336         GetI64ElementsAttr(output_dims, &rewriter));
5337     Value off_value = rewriter.create<BroadcastOp>(
5338         loc, op.getType(), op.off_value(),
5339         GetI64ElementsAttr(output_dims, &rewriter));
5340     Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
5341                                              on_value, off_value);
5342 
5343     rewriter.replaceOp(op, {result});
5344 
5345     return success();
5346   }
5347 };
5348 
5349 // Converts InfeedDequeueTuple to XLA HLO create_token and infeed ops (the
5350 // infeed produces one result per dequeued tensor plus a trailing token).
5351 //
5352 // All HLO infeed ops expect a HLO token type operand and produce a tuple
5353 // containing a token. This HLO token type is used to order multiple infeed
5354 // operations within a computation. The token type can come from other
5355 // infeed/outfeed/send/recv ops or can be generated using create_token op with
5356 // no operands. Here we emit a create_token op to generate the token type
5357 // operand of infeed. The mhlo.InfeedOp can produce multiple results and later
5358 // will be exported to XLA infeed op with single tuple return type.
5359 //
5360 // For example the following IR:
5361 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
5362 //
5363 // would be lowered to
5364 //
5365 // %token = "mhlo.create_token"() : () -> !mhlo.token
5366 // %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} :
5367 //      (!mhlo.token) -> (tensor<3xi32>, tensor<4xf32>, !mhlo.token)
5368 //
5369 class ConvertInfeedDequeueTupleOp
5370     : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
5371  public:
5372   using OpRewritePattern::OpRewritePattern;
5373 
5374   LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
5375                                 PatternRewriter &rewriter) const override {
5376     SmallVector<Type> result_types;
5377     result_types.reserve(op.outputs().size() + 1);
5378     for (const auto &output : op.outputs()) {
5379       Type ty = output.getType();
5380       if (auto tensor_ty = ty.dyn_cast<RankedTensorType>()) {
5381         if (!tensor_ty.hasStaticShape()) return failure();
5382       }
5383       result_types.push_back(ty);
5384     }
5385 
5386     // Infeed takes a single token operand. Generate the token using
5387     // create_token op to pass to the infeed op.
5388     auto token = rewriter.create<CreateTokenOp>(
5389         op.getLoc(), mhlo::TokenType::get(rewriter.getContext()));
5390 
5391     result_types.push_back(token.getType());
5392 
5393     ArrayAttr layout;  // filled in during the xla-adjust-layout pass
5394     auto data_and_token =
5395         rewriter.create<InfeedOp>(op.getLoc(), result_types, token,
5396                                   /*infeed_config=*/rewriter.getStringAttr(""),
5397                                   /*layout=*/layout);
5398 
5399     result_types.pop_back();  // remove the token type.
5400 
5401     if (op._XlaSharding().has_value()) {
5402       // _XlaSharding attribute in TF is a serialized string of the OpSharding
5403       // proto, so convert to a text form here.
5404       ::xla::OpSharding sharding_proto;
5405       if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
5406         return failure();
5407 
5408       // The token is a control signal and not real data, so arbitrarily assign
5409       // the token to device 0.
5410       if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
5411         *sharding_proto.add_tuple_shardings() =
5412             ::xla::sharding_builder::AssignDevice(0);
5413         data_and_token->setAttr(
5414             kShardingAttr,
5415             rewriter.getStringAttr(sharding_proto.SerializeAsString()));
5416       } else {
5417         data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
5418       }
5419     }
5420 
5421     if (op->hasAttr("layouts")) {
5422       // Append a UnitAttr for the "token" operand of the mhlo.infeed op here to
5423       // avoid compilation failure when exporting "layouts" attribute of the
5424       // corresponding InfeedDequeueTupleOp to a graph node.
5425       data_and_token->setAttr("layout", op->getAttr("layouts"));
5426     }
5427     llvm::SmallVector<Value> results;
5428     results.reserve(result_types.size());
5429     for (auto idx_and_type : llvm::enumerate(result_types)) {
5430       results.push_back(data_and_token.getResult(idx_and_type.index()));
5431     }
5432     rewriter.replaceOp(op, ValueRange(results));
5433     return success();
5434   }
5435 };
5436 
5437 // Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
5438 // ops.
5439 //
5440 // XLA HLO outfeed op expects a token, which we generate by emitting an
5441 // create_token op.
5442 //
5443 // For example the following IR:
5444 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
5445 //      ()
5446 //
5447 // would be lowered to
5448 //
5449 // %token = "mhlo.create_token"() : () -> !mhlo.token
5450 // %outfeed_token = "mhlo.outfeed"(%val_1, %val_2, %token)
5451 //      {outfeed_config = ""} :
5452 //      (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token
5453 //
5454 class ConvertOutfeedEnqueueTupleOp
5455     : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
5456  public:
5457   using OpRewritePattern::OpRewritePattern;
5458 
5459   LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
5460                                 PatternRewriter &rewriter) const override {
5461     auto token_type = mhlo::TokenType::get(rewriter.getContext());
5462     auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
5463 
5464     rewriter.create<OutfeedOp>(op.getLoc(), token_type, op.inputs(), token,
5465                                /*outfeed_config=*/rewriter.getStringAttr(""));
5466     rewriter.eraseOp(op);
5467     return success();
5468   }
5469 };
5470 
5471 // Converts tf.TopKV2 to chlo.top_k.
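// For example (a sketch; the `k` operand must be a constant scalar):
//
//   %values, %indices = chlo.top_k(%input, k = 8)
//     : tensor<16x128xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)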
5472 class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
5473  public:
5474   using OpRewritePattern::OpRewritePattern;
5475 
5476   LogicalResult matchAndRewrite(TF::TopKV2Op op,
5477                                 PatternRewriter &rewriter) const override {
5478     // We can only match when the `k` operand is a constant scalar.
5479     DenseIntElementsAttr k_attr;
5480     if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure();
5481     int64_t k = (*k_attr.begin()).getSExtValue();
5482 
5483     TensorType input_type = op.input().getType().cast<TensorType>();
5484     if (!input_type.hasRank()) return failure();
5485     int64_t input_rank = input_type.getRank();
5486     int64_t last_dim_index = input_rank - 1;
5487     int64_t last_dim_size = input_type.getDimSize(last_dim_index);
5488     if (last_dim_size == ShapedType::kDynamicSize) return failure();
5489 
5490     rewriter.replaceOpWithNewOp<chlo::TopKOp>(op, op.input(), k);
5491     return success();
5492   }
5493 };
5494 
5495 // Converts tf.Unpack to a series of XLA HLO slice ops.
5496 //
5497 // Each slice takes one element along the dimension to unpack and takes the full
5498 // range for all other dimensions. Each slice is then reshaped to drop the
5499 // dimension to unpack (which is always of size 1).
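//
// As a sketch, unpacking a tensor<2x3xf32> along axis 0 yields two results,
// each produced by an mhlo.slice with start_indices {i, 0}, limit_indices
// {i + 1, 3} and strides {1, 1}, followed by a tf.Squeeze of the axis-0
// dimension to give tensor<3xf32>.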
5500 // TODO(antiagainst): consider changing this into a TF internal lowering pass.
5501 class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
5502  public:
5503   using OpRewritePattern::OpRewritePattern;
5504 
5505   LogicalResult matchAndRewrite(TF::UnpackOp op,
5506                                 PatternRewriter &rewriter) const override {
5507     auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5508     if (!value_type) return failure();
5509 
5510     int64_t value_rank = value_type.getRank();
5511     int64_t axis = op.axis();
5512     if (axis < 0) axis += value_rank;
5513 
5514     // Parameters for constructing each slice.
5515     SmallVector<int64_t, 4> begin_indices(value_rank, 0);
5516     auto end_indices = llvm::to_vector<4>(value_type.getShape());
5517     SmallVector<int64_t, 4> strides(value_rank, 1);
5518 
5519     // All HLO slice+squeeze results used to replace the original tf.Unpack op.
5520     SmallVector<Value, 4> results;
5521     results.reserve(op.getNumResults());
5522 
5523     for (int i = 0, end = op.getNumResults(); i < end; ++i) {
5524       begin_indices[axis] = i;
5525       end_indices[axis] = i + 1;
5526 
5527       auto slice_op = rewriter.create<mhlo::SliceOp>(
5528           op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
5529           GetI64ElementsAttr(end_indices, &rewriter),
5530           GetI64ElementsAttr(strides, &rewriter));
5531       // Reshape to drop the axis dimension.
5532       auto result =
5533           rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
5534                                          rewriter.getI64ArrayAttr(op.axis()));
5535       results.push_back(result);
5536     }
5537 
5538     rewriter.replaceOp(op, results);
5539     return success();
5540   }
5541 };
5542 
5543 // Converts tf.Unpack to a series of XLA HLO real_dynamic_slice ops.
5544 // TODO(disc): To recover static special case's performance with folding and
5545 // canonicalization.
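//
// As a sketch, each result i is produced by an mhlo.real_dynamic_slice whose
// start/limit/stride operands are 1-D i32 tensors built with
// tensor.from_elements (limit i + 1 along `axis`), followed by an
// mhlo.dynamic_reshape that drops the unit `axis` dimension.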
5546 class ConvertUnpackOpDynamic : public OpRewritePattern<TF::UnpackOp> {
5547  public:
5548   using OpRewritePattern::OpRewritePattern;
5549 
5550   LogicalResult matchAndRewrite(TF::UnpackOp op,
5551                                 PatternRewriter &rewriter) const override {
5552     auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5553     if (!value_type) return failure();
5554     // TODO(disc): Remove this constraint once fold and canonicalization
5555     // implemented.
5556     if (value_type.hasStaticShape()) return failure();
5557 
5558     int64_t value_rank = value_type.getRank();
5559     int64_t axis = op.axis();
5560     if (axis < 0) axis += value_rank;
5561     Location loc = op.getLoc();
5562 
5563     auto shape_scalar_type = rewriter.getIntegerType(32);
5564     // Parameters for constructing each slice.
5565     SmallVector<Value, 4> begin_indices, end_indices, strides;
5566     begin_indices.reserve(value_rank);
5567     end_indices.reserve(value_rank);
5568     strides.reserve(value_rank);
5569     // final output shape
5570     SmallVector<Value, 4> shape_values;
5571     shape_values.reserve(value_rank - 1);
5572     // Slice shape before reshape; should be like {?, 1, ?, ?} if axis = 1.
5573     SmallVector<int64_t, 4> slice_shape(value_rank, ShapedType::kDynamicSize);
5574     for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) {
5575       int64_t dim_size = value_type.getDimSize(dim_idx);
5576       if (dim_size == ShapedType::kDynamicSize) {
5577         Value dim_i = rewriter.create<arith::IndexCastOp>(
5578             loc, shape_scalar_type,
5579             rewriter.create<tensor::DimOp>(loc, op.getOperand(), dim_idx));
5580         end_indices.push_back(dim_i);
5581         if (dim_idx != axis) {
5582           shape_values.push_back(dim_i);
5583         }
5584       } else {
5585         Value dim_i = rewriter.create<arith::ConstantOp>(
5586             loc, shape_scalar_type,
5587             rewriter.getIntegerAttr(shape_scalar_type, dim_size));
5588         end_indices.push_back(dim_i);
5589         if (dim_idx != axis) {
5590           shape_values.push_back(dim_i);
5591           slice_shape[dim_idx] = dim_size;
5592         } else {
5593           slice_shape[dim_idx] = 1;
5594         }
5595       }
5596       begin_indices.push_back(
5597           rewriter.create<arith::ConstantIntOp>(loc, 0, 32));
5598       strides.push_back(rewriter.create<arith::ConstantIntOp>(loc, 1, 32));
5599     }
5600 
5601     SmallVector<Value, 4> results;
5602     results.reserve(op.getNumResults());
5603     Type i32_ty = rewriter.getI32Type();
5604     for (int64_t i = 0; i < op.getNumResults(); ++i) {
5605       begin_indices[axis] = rewriter.create<arith::ConstantIntOp>(loc, i, 32);
5606       end_indices[axis] = rewriter.create<arith::ConstantIntOp>(loc, i + 1, 32);
5607       Value slice_op = rewriter.create<RealDynamicSliceOp>(
5608           loc, RankedTensorType::get(slice_shape, value_type.getElementType()),
5609           op.value(),
5610           rewriter.create<tensor::FromElementsOp>(
5611               loc,
5612               RankedTensorType::get(
5613                   {static_cast<int64_t>(begin_indices.size())}, i32_ty),
5614               begin_indices),
5615           rewriter.create<tensor::FromElementsOp>(
5616               loc,
5617               RankedTensorType::get({static_cast<int64_t>(end_indices.size())},
5618                                     i32_ty),
5619               end_indices),
5620           rewriter.create<tensor::FromElementsOp>(
5621               loc,
5622               RankedTensorType::get({static_cast<int64_t>(strides.size())},
5623                                     i32_ty),
5624               strides));
5625       // Reshape to drop the axis dimension.
5626       Value new_shape = rewriter.create<tensor::FromElementsOp>(
5627           loc,
5628           RankedTensorType::get({static_cast<int64_t>(shape_values.size())},
5629                                 i32_ty),
5630           shape_values);
5631       Value reshape_op = rewriter.create<DynamicReshapeOp>(loc, op.getType(i),
5632                                                            slice_op, new_shape);
5633       results.push_back(reshape_op);
5634     }
5635 
5636     rewriter.replaceOp(op, results);
5637     return success();
5638   }
5639 };
5640 
5641 // Converts the tf.SigmoidGradOp
5642 // TODO(disc): To recover static special case's performance with folding and
5643 // canonicalization.
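// The gradient of the sigmoid is dy * y * (1 - y), where y = sigmoid(x); the
// lowering below materializes exactly this product with broadcasting CHLO
// mul/sub/mul ops so that dynamically shaped operands are supported.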
5644 class ConvertSigmoidGradOpDynamic : public OpRewritePattern<TF::SigmoidGradOp> {
5645  public:
5646   using OpRewritePattern::OpRewritePattern;
5647 
5648   LogicalResult matchAndRewrite(TF::SigmoidGradOp op,
5649                                 PatternRewriter &rewriter) const override {
5650     Location loc = op.getLoc();
5651     Value y = op.y();
5652     Value dy = op.dy();
5653     auto tp_y = y.getType().dyn_cast<RankedTensorType>();
5654     auto tp_dy = dy.getType().dyn_cast<RankedTensorType>();
5655     if (!tp_y || !tp_dy) return failure();
5656 
5657     // TODO(disc): Remove this constraint once fold and canonicalization
5658     // implemented.
5659     if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure();
5660 
5661     Attribute attr;
5662     Type elem_tp = tp_y.getElementType();
5663     if (elem_tp.isSignlessInteger()) {
5664       attr = rewriter.getIntegerAttr(elem_tp, 1);
5665     } else {
5666       assert(elem_tp.isa<FloatType>());
5667       attr = rewriter.getFloatAttr(elem_tp, 1);
5668     }
5669     Value one = rewriter.create<mhlo::ConstantOp>(
5670         loc, DenseElementsAttr::get(RankedTensorType::get({}, elem_tp), attr));
5671 
5672     auto v0 = rewriter.create<chlo::BroadcastMulOp>(
5673         loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y));
5674     auto v1 = rewriter.create<chlo::BroadcastSubOp>(
5675         loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y));
5676     auto result = rewriter.create<chlo::BroadcastMulOp>(
5677         loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1));
5678 
5679     rewriter.replaceOp(op, result.getOperation()->getResults());
5680     return success();
5681   }
5682 };
5683 
5684 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
5685 //
5686 // TF unsorted segment reduction op performs the following calculation:
5687 //
5688 // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
5689 // [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
5690 // shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
5691 // Dm+1, ..., Dn]. Then
5692 //   output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
5693 //      <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
5694 // where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
5695 //
5696 // The op will be translated to XLA HLO scatter with the following parameters:
5697 // * Update window dims is [segment_id_rank, data_rank).
5698 // * Inserted window dims is {0}.
5699 // * Scatter dims to operand dims mapping is {0}.
5700 // * Index vector dim is segment_id_rank.
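//
// As a worked example (shapes and values are hypothetical): for data of shape
// [4, 3], segment_ids of shape [4] with values [1, 0, 1, 1] and
// num_segments = 2, the output has shape [2, 3]; output row 0 is the reduction
// of data row 1 and output row 1 is the reduction of data rows 0, 2 and 3.
// The generated scatter then uses update window dims {1}, inserted window
// dims {0}, scatter dims to operand dims {0} and index vector dim 1.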
5701 template <typename ConcreteClass, typename OpTy, typename ReductionOp>
5702 class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
5703   using OpRewritePattern<OpTy>::OpRewritePattern;
5704 
5705   LogicalResult matchAndRewrite(OpTy op,
5706                                 PatternRewriter &rewriter) const override {
5707     auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
5708     if (!data_type) return failure();
5709     int64_t data_rank = data_type.getRank();
5710 
5711     auto segment_ids_type =
5712         op.segment_ids().getType().template dyn_cast<RankedTensorType>();
5713     if (!segment_ids_type) return failure();
5714     int64_t segment_ids_rank = segment_ids_type.getRank();
5715 
5716     DenseIntElementsAttr num_segments_attr;
5717     if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
5718       return failure();
5719 
5720     // The final shape for TF unsorted segment reduction op is [num_segments] +
5721     // data_shape[segment_ids_rank:].
5722     SmallVector<int64_t, 4> output_shape;
5723     output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
5724     auto suffix = data_type.getShape().drop_front(segment_ids_rank);
5725     output_shape.append(suffix.begin(), suffix.end());
5726     auto output_type =
5727         RankedTensorType::get(output_shape, data_type.getElementType());
5728 
5729     // Broadcast the initial value for reduction. This will become the
5730     // 'operand' parameter of the final scatter op.
5731     Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
5732                                                 op.getLoc(), &rewriter);
5733     auto broadcasted_init = rewriter.create<mhlo::BroadcastOp>(
5734         op.getLoc(), output_type, init,
5735         GetI64ElementsAttr(output_shape, &rewriter));
5736 
5737     // Parameters for the generated scatter op.
5738     SmallVector<int64_t, 1> inserted_window_dims(1, 0);
5739     SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
5740     int64_t index_vector_dim = segment_ids_rank;
5741 
5742     // Put all parameters in a ScatterDimensionNumbersAttr.
5743     auto dims_attr = ScatterDimensionNumbersAttr::get(
5744         rewriter.getContext(),
5745         llvm::to_vector<4>(llvm::seq<int64_t>(segment_ids_rank, data_rank)),
5746         inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim);
5747 
5748     auto scatter = rewriter.create<ScatterOp>(
5749         op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)),
5750         op.segment_ids(), op.data(), dims_attr);
5751     BuildReduceBody<ReductionOp>(data_type.getElementType(),
5752                                  &scatter.update_computation(), &rewriter);
5753 
5754     rewriter.replaceOp(op, scatter.getResult(0));
5755     return success();
5756   }
5757 };
5758 
5759 class ConvertUnsortedSegmentMaxOp
5760     : public GenericConvertUnsortedSegmentReductionOp<
5761           ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
5762  public:
5763   using GenericConvertUnsortedSegmentReductionOp::
5764       GenericConvertUnsortedSegmentReductionOp;
5765 
5766   static Value GetInitialValue(Type reduce_element_type, Location loc,
5767                                PatternRewriter *rewriter) {
5768     return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
5769                                      rewriter);
5770   }
5771 };
5772 
5773 class ConvertUnsortedSegmentMinOp
5774     : public GenericConvertUnsortedSegmentReductionOp<
5775           ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
5776  public:
5777   using GenericConvertUnsortedSegmentReductionOp::
5778       GenericConvertUnsortedSegmentReductionOp;
5779 
5780   static Value GetInitialValue(Type reduce_element_type, Location loc,
5781                                PatternRewriter *rewriter) {
5782     return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
5783                                      rewriter);
5784   }
5785 };
5786 
5787 class ConvertUnsortedSegmentProdOp
5788     : public GenericConvertUnsortedSegmentReductionOp<
5789           ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
5790  public:
5791   using GenericConvertUnsortedSegmentReductionOp::
5792       GenericConvertUnsortedSegmentReductionOp;
5793 
5794   static Value GetInitialValue(Type reduce_element_type, Location loc,
5795                                PatternRewriter *rewriter) {
5796     return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
5797   }
5798 };
5799 
5800 class ConvertUnsortedSegmentSumOp
5801     : public GenericConvertUnsortedSegmentReductionOp<
5802           ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
5803  public:
5804   using GenericConvertUnsortedSegmentReductionOp::
5805       GenericConvertUnsortedSegmentReductionOp;
5806 
5807   static Value GetInitialValue(Type reduce_element_type, Location loc,
5808                                PatternRewriter *rewriter) {
5809     return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
5810   }
5811 };
5812 
5813 // Converts tf.RandomShuffle op into a series of XLA HLO ops.
5814 //
5815 // tf.RandomShuffle shuffles tensors along the first dimension. If the input
5816 // tensor's rank is 1, then it is translated into HLO sort op(s) according to
5817 // indices randomly generated via HLO rng_uniform ops. Otherwise, it is
5818 // translated into an HLO while op to first emulate shuffling indices using
5819 // HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather
5820 // with the shuffled indices.
5821 class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
5822  public:
5823   using OpRewritePattern::OpRewritePattern;
5824 
5825   LogicalResult matchAndRewrite(TF::RandomShuffleOp op,
5826                                 PatternRewriter &rewriter) const override {
5827     auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
5828     if (!input_type) return failure();
5829 
5830     int64_t input_rank = input_type.getRank();
5831     int64_t first_dim_size = input_type.getDimSize(0);
5832     if (ShapedType::isDynamic(first_dim_size)) return failure();
5833 
5834     // We are shuffling along the first dimension. If its size is <= 1, then
5835     // shuffling is a no-op.
5836     if (first_dim_size <= 1) {
5837       rewriter.replaceOp(op, op.value());
5838       return success();
5839     }
5840 
5841     // For vectors, shuffle values by sorting instead of the obvious
5842     // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
5843     // but not easily parallelizable. For a sufficiently parallel architecture,
5844     // it is faster to sort many times than to Fisher-Yates shuffle once.
5845     if (input_rank == 1) {
5846       // Shuffle values by assigning each value a random key and sorting the
5847       // keys. Keys can collide causing detectable patterns in the shuffled
5848       // output. Collisions translate into more ascending sub-sequences in the
5849       // shuffled output than would be expected by chance. To avoid collisions,
5850       // the number of possible key values must be sufficiently large.
5851 
5852       // How are more than 2^32 keys created? In each loop iteration, the
5853       // algorithm sorts by random keys. Conceptually, the earlier iterations
5854       // are sorting on the lower-order bits of larger keys that are never
5855       // actually assembled.
5856 
5857       // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
5858       // the number of possible keys and n is the number of values. If d = n^2,
5859       // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
5860       // as n goes to infinity is zero.
5861 
5862       // This implementation ensures that the key-space is greater than or equal
5863       // to the cube of the number of values. The risk of collisions can be
5864       // further reduced by increasing Exponent at the expense of
5865       // performance.
5866 
5867       // For Exponent = 2, the expected number of collisions per shuffle is
5868       // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
5869       // about 1/2.
5870 
5871       // For Exponent = 3, the expected number of collisions per shuffle is
5872       // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
5873       // about 1/3255.
5874 
5875       // For Exponent = 4, the expected number of collisions per shuffle is
5876       // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
5877       // about 1/132622.
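      // As a worked example (the element count is illustrative): with
      // exponent = 3 and num_elements = 1e6,
      //   rounds = ceil(3 * ln(1e6) / ln(2^32 - 1)) = ceil(41.45 / 22.18) = 2,
      // so the vector is sorted twice using independently drawn 32-bit keys.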
5878       constexpr int exponent = 3;
5879       int64_t num_elements = input_type.getNumElements();
5880       uint32_t u32_max = std::numeric_limits<uint32_t>::max();
5881       int rounds =
5882           std::ceil(exponent * std::log(num_elements) / std::log(u32_max));
5883 
5884       Value current = op.value();
5885       for (int i = 0; i < rounds; ++i) {
5886         auto keys =
5887             CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
5888                                /*upper_limit=*/u32_max, &rewriter);
5889         auto sorted = createSortOp(
5890             &rewriter, op.getLoc(), {keys, current},
5891             {rewriter.getIntegerType(32), input_type.getElementType()},
5892             /*dimension=*/-1, /*is_stable=*/false,
5893             /*direction=*/ComparisonDirection::LT);
5894         current = sorted.getResult(1);
5895       }
5896       rewriter.replaceOp(op, current);
5897       return success();
5898     }
5899 
5900     // The Fisher-Yates algorithm.
5901 
5902     // Generate range(n) as the initial value for the indices to be swapped.
5903     auto indices_type =
5904         RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
5905     Value indices = rewriter.create<mhlo::IotaOp>(
5906         op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));
5907 
5908     // Generate random numbers to be used as swaps for the indices.
5909     Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
5910                                      first_dim_size, &rewriter);
5911 
5912     // While loop body to perform index swaps.
5913     auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
5914                             SmallVectorImpl<Value> *new_values,
5915                             OpBuilder *builder) {
5916       Value swaps = old_values[0];
5917       Value indices = old_values[1];
5918 
5919       auto scalar_i32_type =
5920           RankedTensorType::get({}, builder->getIntegerType(32));
5921       auto scalar_i64_type =
5922           RankedTensorType::get({}, builder->getIntegerType(64));
5923 
5924       auto scalar_one =
5925           DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));
5926 
5927       // We need to swap the indices[i] with indices[swaps[i]]. First get
5928       // these index values.
5929       Value source_index =
5930           builder->create<mhlo::DynamicSliceOp>(loc, indices, i, scalar_one);
5931       Value swap_index = builder->create<mhlo::ReshapeOp>(
5932           loc, scalar_i32_type,
5933           builder->create<mhlo::DynamicSliceOp>(loc, swaps, i, scalar_one));
5934       Value target_index = builder->create<mhlo::DynamicSliceOp>(
5935           loc, indices, swap_index, scalar_one);
5936 
5937       // Then perform the swap.
5938       // indices[i] <- indices[swaps[i]]
5939       indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5940           loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
5941       // indices[swaps[i]] <- indices[i]
5942       indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5943           loc, indices.getType(), indices, source_index,
5944           llvm::makeArrayRef(swap_index));
5945 
5946       // Update new values.
5947       new_values->assign({swaps, indices});
5948     };
5949 
5950     // Create a while op to swap indices.
5951     SmallVector<Value, 2> while_output;
5952     CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
5953                   &while_output, &rewriter);
5954     Value swapped_indices = while_output[1];
5955 
5956     // Gather the data using the swapped indices as the shuffled order.
5957     ArrayRef<int64_t> input_shape = input_type.getShape();
5958     SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
5959     slice_sizes[0] = 1;
5960     auto dims_attr = GatherDimensionNumbersAttr::get(
5961         rewriter.getContext(),
5962         /*offset_dims=*/llvm::to_vector<4>(llvm::seq<int64_t>(1, input_rank)),
5963         /*collapsed_slice_dims=*/{0},
5964         /*start_index_map=*/{0},
5965         /*index_vector_dim=*/1);
5966     rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
5967         op, op.getType(), op.value(), swapped_indices, dims_attr,
5968         GetI64ElementsAttr(slice_sizes, &rewriter));
5969 
5970     return success();
5971   }
5972 };
5973 
5974 // Converts an XlaSharding op to an XLA HLO shard op with sharding attributes.
5975 class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
5976  public:
5977   using OpRewritePattern::OpRewritePattern;
5978 
5979   LogicalResult matchAndRewrite(TF::XlaShardingOp op,
5980                                 PatternRewriter &rewriter) const override {
5981     // TODO(b/148313088): define sharding attribute struct in MLIR instead of
5982     // using a string.
5983     if (!op._XlaSharding().has_value()) return failure();
5984 
5985     NamedAttribute call_target_name = rewriter.getNamedAttr(
5986         "call_target_name", rewriter.getStringAttr("Sharding"));
5987 
5988     auto custom_call = rewriter.create<mhlo::CustomCallOp>(
5989         op.getLoc(), op.getType(), op.input(),
5990         ArrayRef<NamedAttribute>{call_target_name});
5991     custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
5992     rewriter.replaceOp(op, custom_call.getResult(0));
5993 
5994     return success();
5995   }
5996 };
5997 
5998 // Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
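//
// For illustration (shapes are hypothetical): with x of shape [5, 3],
// i = [4, 1] and v of shape [2, 3], the op behaves like
//   for j in range(2):
//     x[i[j], :] = v[j, :]
// The lowering below unpacks i into scalars, splits v into rows, and emits one
// DynamicUpdateSlice per row.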
5999 class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
6000  public:
6001   using OpRewritePattern::OpRewritePattern;
6002 
6003   LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
6004                                 PatternRewriter &rewriter) const override {
6005     auto input = op.x();
6006     auto indices = op.i();
6007     auto updates = op.v();
6008 
6009     // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
6010     // on the contents of `x`.
6011     auto input_type = input.getType().cast<ShapedType>();
6012     auto updates_type = updates.getType().cast<ShapedType>();
6013     auto indices_type = indices.getType().cast<ShapedType>();
6014     if (!input_type.hasRank()) return failure();
6015     if (!updates_type.hasRank() || updates_type.isDynamicDim(0))
6016       return failure();
6017     if (!indices_type.hasStaticShape()) return failure();
6018 
6019     if (indices_type.getRank() != 1) return failure();
6020 
6021     SmallVector<Type, 4> unpacked_indices_type(
6022         indices_type.getDimSize(0),
6023         RankedTensorType::get({}, indices_type.getElementType()));
6024     // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are
6025     // required to have matching types. This rewrite rule creates
6026     // DynamicUpdateSlice ops where the first "start index" is always i32 and
6027     // subsequent ones are constructed based on zero_attr. Thus the type
6028     // for zero_attr needs to be i32 as well.
6029     auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0);
6030     auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6031         op.getLoc(), unpacked_indices_type, indices, zero_attr);
6032 
6033     SmallVector<int64_t, 4> split_updates_shape;
6034     split_updates_shape.append(updates_type.getShape().begin(),
6035                                updates_type.getShape().end());
6036     split_updates_shape.front() = 1;
6037     SmallVector<Type, 4> split_updates_type;
6038     split_updates_type.resize(
6039         updates_type.getShape().front(),
6040         RankedTensorType::get(split_updates_shape,
6041                               updates_type.getElementType()));
6042 
6043     auto cst =
6044         rewriter.create<mhlo::ConstantOp>(op.getLoc(), zero_attr).getResult();
6045     auto split_updates = rewriter.create<TF::SplitOp>(
6046         op.getLoc(), split_updates_type, cst, updates);
6047 
6048     SmallVector<Value, 6> input_indices;
6049     input_indices.resize(input_type.getRank(), cst);
6050 
6051     for (auto pair :
6052          llvm::zip(unpacked_indices.output(), split_updates.output())) {
6053       input_indices.front() = std::get<0>(pair);
6054       input = rewriter.create<mhlo::DynamicUpdateSliceOp>(
6055           op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
6056     }
6057 
6058     rewriter.replaceOp(op, input);
6059     return success();
6060   }
6061 };
6062 
6063 // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
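// The 1-D `indices` operand is unpacked into one scalar start index per
// dimension of `input`, which is the form the HLO op expects.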
6064 class ConvertXlaDynamicUpdateSliceOp
6065     : public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
6066  public:
6067   using OpRewritePattern::OpRewritePattern;
6068 
6069   LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op,
6070                                 PatternRewriter &rewriter) const override {
6071     auto indices_type = op.indices().getType().dyn_cast<RankedTensorType>();
6072     if (!indices_type || !indices_type.hasStaticShape() ||
6073         indices_type.getShape().size() != 1)
6074       return failure();
6075 
6076     SmallVector<Type, 4> unpacked_indices_type(
6077         indices_type.getDimSize(0),
6078         RankedTensorType::get({}, indices_type.getElementType()));
6079     auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6080         op.getLoc(), unpacked_indices_type, op.indices(),
6081         IntegerAttr::get(rewriter.getIntegerType(64), 0));
6082     rewriter.replaceOpWithNewOp<mhlo::DynamicUpdateSliceOp>(
6083         op, op.getType(), op.input(), op.update(), unpacked_indices.output());
6084     return success();
6085   }
6086 };
6087 
6088 // Converts a TF XlaReduceScatter op to ReduceScatter HLO.
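// The reduction computation is chosen from reduce_op; "Mean" is lowered as an
// "Add" reduce-scatter whose result is then divided by the replica group size
// (the second dimension of group_assignment).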
6089 class ConvertXlaReduceScatterOp
6090     : public OpRewritePattern<TF::XlaReduceScatterOp> {
6091   using OpRewritePattern::OpRewritePattern;
6092 
6093   LogicalResult matchAndRewrite(TF::XlaReduceScatterOp op,
6094                                 PatternRewriter &rewriter) const override {
6095     DenseIntElementsAttr group_assignment;
6096     if (!matchPattern(op.group_assignment(), m_Constant(&group_assignment)))
6097       return failure();
6098     auto replica_groups =
6099         hlo::convertElementsAttr(group_assignment, rewriter.getIntegerType(64))
6100             .cast<DenseIntElementsAttr>();
6101     if (replica_groups.getType().getRank() != 2) return failure();
6102 
6103     APInt scatter_dimension;
6104     if (!matchPattern(op.scatter_dimension(),
6105                       m_ConstantInt(&scatter_dimension)))
6106       return failure();
6107 
6108     Location loc = op.getLoc();
6109     Type element_type = getElementTypeOrSelf(op.input().getType());
6110 
6111     auto reduce_scatter = rewriter.create<ReduceScatterOp>(
6112         loc, op.getType(), op.input(),
6113         rewriter.getIntegerAttr(rewriter.getIntegerType(64),
6114                                 scatter_dimension.getSExtValue()),
6115         replica_groups, ChannelHandleAttr());
6116     StringRef reduce_op = op.reduce_op();
6117     if (reduce_op == "Add") {
6118       BuildReduceBody<AddOp>(element_type, &reduce_scatter.computation(),
6119                              &rewriter);
6120     } else if (reduce_op == "Mul") {
6121       BuildReduceBody<MulOp>(element_type, &reduce_scatter.computation(),
6122                              &rewriter);
6123     } else if (reduce_op == "Min") {
6124       BuildReduceBody<MinOp>(element_type, &reduce_scatter.computation(),
6125                              &rewriter);
6126     } else if (reduce_op == "Max") {
6127       BuildReduceBody<MaxOp>(element_type, &reduce_scatter.computation(),
6128                              &rewriter);
6129     } else {
6130       // For mean, first sum across the replicas in the same group; the sum is
6131       // divided by the group size further below.
6132       assert(reduce_op == "Mean");
6133       BuildReduceBody<AddOp>(element_type, &reduce_scatter.computation(),
6134                              &rewriter);
6135     }
6136     Value result = reduce_scatter.getResult();
6137 
6138     // For mean, divide the merge result by group size.
6139     if (reduce_op == "Mean") {
6140       int64_t replica_group_size = replica_groups.getType().getDimSize(1);
6141       if (replica_group_size == 0) return failure();
6142       auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size,
6143                                           &rewriter);
6144       auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
6145       result = rewriter.create<chlo::BroadcastDivOp>(
6146           loc, result, divisor.getResult(), broadcast_dims);
6147     }
6148 
6149     rewriter.replaceOp(op, {result});
6150     return success();
6151   }
6152 };
6153 
6154 // Converts tf.XlaReduceWindow to mhlo.ReduceWindow
6155 class ConvertXlaReduceWindowOp
6156     : public OpRewritePattern<TF::XlaReduceWindowOp> {
6157   using OpRewritePattern::OpRewritePattern;
6158 
6159   LogicalResult matchAndRewrite(TF::XlaReduceWindowOp op,
6160                                 PatternRewriter &rewriter) const override {
6161     DenseElementsAttr window_dimensions, window_strides, base_dilations,
6162         window_dilations, padding;
6163     if (!(matchPattern(op.window_dimensions(),
6164                        m_Constant(&window_dimensions)) &&
6165           matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
6166           matchPattern(op.base_dilations(), m_Constant(&base_dilations)) &&
6167           matchPattern(op.window_dilations(), m_Constant(&window_dilations)) &&
6168           matchPattern(op.padding(), m_Constant(&padding))))
6169       return failure();
6170 
6171     Location loc = op.getLoc();
6172 
6173     SmallVector<Type> result_types{op.getResult().getType()};
6174     // Create the mhlo.ReduceWindow op.
6175     auto reduce_window_op = rewriter.create<mhlo::ReduceWindowOp>(
6176         loc, result_types, op.input(), op.init_value(),
6177         hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64))
6178             .cast<DenseIntElementsAttr>(),
6179         hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64))
6180             .cast<DenseIntElementsAttr>(),
6181         hlo::convertElementsAttr(base_dilations, rewriter.getIntegerType(64))
6182             .cast<DenseIntElementsAttr>(),
6183         hlo::convertElementsAttr(window_dilations, rewriter.getIntegerType(64))
6184             .cast<DenseIntElementsAttr>(),
6185         hlo::convertElementsAttr(padding, rewriter.getIntegerType(64))
6186             .cast<DenseIntElementsAttr>());
6187     // Insert a call to the reducer in the region of the mhlo op.
6188     mlir::SymbolRefAttr func = op.computation();
6189     auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
6190         op->getParentOfType<mlir::ModuleOp>(), func));
6191     auto func_ty = func_op.getFunctionType();
6192     BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_window_op.body());
6193 
6194     rewriter.replaceOp(op, reduce_window_op.getResults());
6195 
6196     return success();
6197   }
6198 };
6199 
6200 // Converts ClipByValue to XLA's clamp operation. Includes the broadcasting
6201 // semantics for static and dynamic cases.
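// For example (values are hypothetical): clip_by_value(t, 2, 5) applied to
// t = [1, 3, 7] yields [2, 3, 5], i.e. clamp(min, t, max). Clip bounds whose
// type does not already match t (e.g. scalars) are first broadcast to its
// shape.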
6202 class ConvertClipByValueOp : public OpRewritePattern<TF::ClipByValueOp> {
6203  public:
6204   using OpRewritePattern::OpRewritePattern;
6205 
6206   LogicalResult matchAndRewrite(TF::ClipByValueOp op,
6207                                 PatternRewriter &rewriter) const override {
6208     Value input = op.t();
6209     Value min = op.clip_value_min();
6210     Value max = op.clip_value_max();
6211 
6212     auto input_ty = input.getType().cast<ShapedType>();
6213     auto min_ty = min.getType().cast<ShapedType>();
6214     auto max_ty = max.getType().cast<ShapedType>();
6215 
6216     if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) {
6217       return failure();
6218     }
6219 
6220     auto shape = rewriter.create<TF::ShapeOp>(
6221         op.getLoc(),
6222         RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()),
6223         input);
6224 
6225     if (min_ty != input_ty) {
6226       min =
6227           rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, min, shape);
6228     }
6229 
6230     if (max_ty != input_ty) {
6231       max =
6232           rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, max, shape);
6233     }
6234 
6235     rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, input_ty, min, input, max);
6236     return success();
6237   }
6238 };
6239 
6240 // Converts ConstOp to XLA's constant operation and introduces a tensor cast if
6241 // needed.
6242 class ConvertConstOp : public OpRewritePattern<TF::ConstOp> {
6243  public:
6244   using OpRewritePattern::OpRewritePattern;
6245 
6246   LogicalResult matchAndRewrite(TF::ConstOp op,
6247                                 PatternRewriter &rewriter) const override {
6248     // Convert only for valid HLO tensors.
6249     auto ty = op.getType().dyn_cast<TensorType>();
6250     if (!ty || !ty.getElementType().isa<FloatType, IntegerType, ComplexType>())
6251       return failure();
6252 
6253     Location loc = op.getLoc();
6254     Value result = rewriter.create<mhlo::ConstantOp>(loc, op.value());
6255     if (result.getType() != op.getType())
6256       result = rewriter.create<tensor::CastOp>(loc, op.getType(), result);
6257     rewriter.replaceOp(op, result);
6258     return success();
6259   }
6260 };
6261 
6262 // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
6263 // setting appropriate window dimensions, with the given aggregation op as the
6264 // reduction function. The input tensor needs to have a static shape, and 'axis'
6265 // must be const. The TableGen pattern is not used for this rewrite because it
6266 // involves regions.
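//
// As a worked example (the input is hypothetical): a 1-D Cumsum over
// [1, 2, 3, 4] uses window_dims = [4], window_strides = [1] and
// paddings = [[3, 0]], so output[i] reduces input[max(0, i - 3) .. i] and the
// result is the inclusive prefix sum [1, 3, 6, 10]. With exclusive = true the
// result is shifted via a PadOp to [0, 1, 3, 6]; with reverse = true the input
// and the result are reversed along the axis.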
6267 template <typename OpT, typename AggregationOp>
6268 class ConvertCumOp : public OpRewritePattern<OpT> {
6269   using OpRewritePattern<OpT>::OpRewritePattern;
6270 
6271   LogicalResult matchAndRewrite(OpT op,
6272                                 PatternRewriter &rewriter) const override {
6273     auto input = op.x();
6274     auto input_type = input.getType().template dyn_cast<ShapedType>();
6275     if (!input_type || !input_type.hasStaticShape()) {
6276       return failure();
6277     }
6278 
6279     ArrayRef<int64_t> input_shape = input_type.getShape();
6280     int64_t rank = input_shape.size();
6281 
6282     // We can only match when the axis is a constant scalar.
6283     DenseIntElementsAttr axis_attr;
6284     if (!matchPattern(op.axis(), m_Constant(&axis_attr))) {
6285       return failure();
6286     }
6287 
6288     // Get the dimension to apply the reduction on, and offset properly if it is
6289     // negative.
6290     int64_t axis = (*axis_attr.begin()).getSExtValue();
6291     if (axis < 0) {
6292       axis += rank;
6293     }
6294 
6295     // If we're supposed to sum things up in the reverse direction, we reverse
6296     // the input and then later reverse the output.
6297     if (op.reverse()) {
6298       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6299       input = rewriter.create<ReverseOp>(
6300           op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter));
6301     }
6302 
6303     // Convert if we need to enlarge the element type's bitwidth to avoid
6304     // precision loss.
6305     Type input_element_type = input_type.getElementType();
6306 
6307     // TODO(hinsu): Handle complex element types.
6308     if (!input_element_type.isIntOrFloat()) return failure();
6309 
6310     Type sum_element_type = GetSumAccumulationType(input_element_type);
6311     input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
6312 
6313     SmallVector<int64_t, 4> window_dims(rank, 1);
6314     SmallVector<int64_t, 4> window_strides(rank, 1);
6315     window_dims[axis] = input_shape[axis];
6316 
6317     SmallVector<int64_t, 8> paddings(rank * 2, 0);
6318     paddings[axis * 2] =
6319         std::max(input_shape[axis] - 1, static_cast<int64_t>(0));
6320     auto paddings_attr = DenseIntElementsAttr::get(
6321         RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
6322         paddings);
6323 
6324     int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
6325     Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
6326                                       &rewriter);
6327 
6328     auto reduce = rewriter.create<ReduceWindowOp>(
6329         op.getLoc(), input.getType(), input, init,
6330         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)),
6331         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
6332         /*base_dilations=*/DenseIntElementsAttr(),
6333         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
6334     BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
6335     Value result = reduce.getResult(0);
6336 
6337     if (op.exclusive()) {
6338       // In "exclusive" operation, the output will start with the "init" (0)
6339       // values. There is no way to express that as a ReduceWindowOp, so run the
6340       // normal operation, and then use a PadOp to add the 0 "column" on the
6341       // left and cut away the last column on the right.
6342       llvm::SmallVector<int64_t, 4> low_padding(rank, 0);
6343       llvm::SmallVector<int64_t, 4> high_padding(rank, 0);
6344       llvm::SmallVector<int64_t, 4> interior_padding(rank, 0);
6345       low_padding[axis] = 1;
6346       high_padding[axis] = -1;
6347       result = rewriter.create<PadOp>(
6348           op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter),
6349           GetI64ElementsAttr(high_padding, &rewriter),
6350           GetI64ElementsAttr(interior_padding, &rewriter));
6351     }
6352 
6353     // Convert back if we enlarged the element type's bitwidth.
6354     result =
6355         rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
6356 
6357     if (op.reverse()) {
6358       llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6359       result = rewriter.create<ReverseOp>(
6360           op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter));
6361     }
6362 
6363     rewriter.replaceOp(op, result);
6364     return success();
6365   }
6366 };
6367 
6368 using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
6369 using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
6370 
6371 // Converts the TensorFlow ShapeOp to a sequence of Shape dialect and Standard
6372 // dialect lowerings. This involves extracting the shape type, extracting and
6373 // converting each dimension to a known integer type, and repacking into a final
6374 // tensor.
6375 class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
6376  public:
6377   using OpRewritePattern::OpRewritePattern;
6378 
6379   LogicalResult matchAndRewrite(TF::ShapeOp op,
6380                                 PatternRewriter &rewriter) const override {
6381     Value input = op.input();
6382 
6383     auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
6384     if (!result_ty) {
6385       return failure();
6386     }
6387 
6388     auto index_tensor =
6389         RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
6390     auto shape_op =
6391         rewriter.create<shape::ShapeOfOp>(op.getLoc(), index_tensor, input);
6392     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, result_ty, shape_op);
6393     return success();
6394   }
6395 };
6396 
6397 class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
6398  public:
6399   using OpRewritePattern::OpRewritePattern;
6400 
6401   LogicalResult matchAndRewrite(TF::ExpandDimsOp op,
6402                                 PatternRewriter &rewriter) const override {
6403     auto input = op.input();
6404     auto input_ty = input.getType().cast<ShapedType>();
6405     auto result_ty = op.getType().cast<ShapedType>();
6406     if (!result_ty.hasRank() || !input_ty.hasRank() ||
6407         result_ty.hasStaticShape()) {
6408       return failure();
6409     }
6410 
6411     DenseIntElementsAttr expand_dims_attr;
6412     if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) {
6413       return failure();
6414     }
6415 
6416     auto shape = rewriter.create<shape::ShapeOfOp>(
6417         op.getLoc(),
6418         RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()),
6419         input);
6420     auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getValues<APInt>());
6421 
6422     llvm::SmallVector<Value, 4> dims;
6423     dims.resize(result_ty.getRank());
6424 
6425     auto inserted_dim = expand_dims[0].getSExtValue();
6426 
6427     // Handle the negative value use case.
6428     if (inserted_dim < 0) {
6429       inserted_dim += result_ty.getRank();
6430       // This means the value is completely out of range; just return failure.
6431       if (inserted_dim < 0) {
6432         return failure();
6433       }
6434     }
6435 
6436     dims[inserted_dim] =
6437         rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
6438 
6439     for (int i = 0; i < dims.size() - 1; i++) {
6440       // Add the extracted dim.
6441       Value index = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i);
6442       Value dim = rewriter.create<tensor::ExtractOp>(op.getLoc(), shape, index);
6443       dims[i >= inserted_dim ? i + 1 : i] = dim;
6444     }
6445 
6446     auto from_extents =
6447         rewriter.create<tensor::FromElementsOp>(op.getLoc(), dims);
6448     rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, input,
6449                                                         from_extents);
6450     return success();
6451   }
6452 };
6453 
6454 class ConvertDynamicSqueezeOp : public OpRewritePattern<TF::SqueezeOp> {
6455  public:
6456   using OpRewritePattern::OpRewritePattern;
6457 
6458   LogicalResult matchAndRewrite(TF::SqueezeOp op,
6459                                 PatternRewriter &rewriter) const override {
6460     auto input = op.input();
6461     auto input_ty = input.getType().cast<ShapedType>();
6462     auto result_ty = op.getType().cast<ShapedType>();
6463     if (!result_ty.hasRank() || !input_ty.hasRank() ||
6464         result_ty.hasStaticShape()) {
6465       return failure();
6466     }
6467 
6468     // The fully dynamic case is unsupported.
6469     if (op.squeeze_dims().empty()) {
6470       return failure();
6471     }
6472 
6473     SmallVector<int64_t> squeeze_dims;
6474     int64_t input_rank = input_ty.getRank();
6475     for (const auto &squeeze_dim_apint :
6476          op.squeeze_dims().getAsValueRange<IntegerAttr>()) {
6477       int64_t squeeze_dim = squeeze_dim_apint.getSExtValue();
6478       // Handle negative inputs.
6479       if (squeeze_dim < 0) squeeze_dim += input_rank;
6480       assert(squeeze_dim >= 0 && squeeze_dim < input_rank &&
6481              "squeeze dim out of bounds");
6482 
6483       squeeze_dims.push_back(squeeze_dim);
6484     }
6485 
6486     // Collect the unsqueezed dimensions.
6487     llvm::SmallVector<Value> dims;
6488     for (int64_t i = 0; i != input_rank; ++i) {
6489       if (llvm::is_contained(squeeze_dims, i)) continue;
6490       dims.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), input, i));
6491     }
6492 
6493     auto from_extents =
6494         rewriter.create<tensor::FromElementsOp>(op.getLoc(), dims);
6495     // chlo::DynamicReshapeOp checks if the reshape is legal and will fail if
6496     // any non-1 dimension is squeezed.
6497     rewriter.replaceOpWithNewOp<chlo::DynamicReshapeOp>(op, result_ty, input,
6498                                                         from_extents);
6499     return success();
6500   }
6501 };
6502 
6503 // Converts a TF QR op to HLO.
6504 class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
6505  public:
6506   using OpRewritePattern::OpRewritePattern;
6507 
6508   LogicalResult matchAndRewrite(TF::QrOp op,
6509                                 PatternRewriter &rewriter) const override {
6510     // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van
6511     // Loan. def qr_blocked(a, block_size):
6512     //   m = a.shape[0]
6513     //   n = a.shape[1]
6514     //   q = np.eye(m)
6515     //   for i in xrange(0, min(m, n), block_size):
6516     //     k = min(block_size, min(m, n) - i)
6517     //     (a, vs, taus) = qr(a[i:, i:i+k])
6518     //     y = vs
6519     //     w = ComputeWYRepresentation(vs, taus, m-i, k)
6520     //     a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
6521     //     q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
6522     //   return (q, a)
6523     auto type = op.input().getType().dyn_cast<RankedTensorType>();
6524     if (!type || !type.hasStaticShape()) return failure();
6525     // The block size is chosen to match the old bridge lowering.
6526     constexpr int64_t kBlockSize = 128;
6527     Value a = op.input();
6528     int64_t m = type.getDimSize(type.getRank() - 2);
6529     int64_t n = type.getDimSize(type.getRank() - 1);
6530     int64_t p = std::min(m, n);
6531     auto batch_dims = type.getShape().drop_back(2);
6532     auto iota_type = RankedTensorType::get({m, m}, rewriter.getIntegerType(32));
6533     auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6534                                          rewriter.getI64IntegerAttr(0));
6535     auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6536                                          rewriter.getI64IntegerAttr(1));
6537     Value compare = rewriter.create<CompareOp>(op.getLoc(), iota0, iota1,
6538                                                ComparisonDirection::EQ);
6539     Value identity_matrix =
6540         rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
6541     auto q_shape = llvm::to_vector<4>(type.getShape());
6542     q_shape.back() = m;
6543     Value q =
6544         rewriter.create<BroadcastOp>(op.getLoc(), identity_matrix,
6545                                      GetI64ElementsAttr(batch_dims, &rewriter));
6546     auto precision_config = rewriter.getArrayAttr(
6547         {PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST),
6548          PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST)});
6549     for (int64_t i = 0; i < p; i += kBlockSize) {
6550       int64_t k = std::min(kBlockSize, p - i);
6551       auto a_block =
6552           SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter);
6553       Value r_block;
6554       Value taus;
6555       Value vs;
6556       QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter);
6557       a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter);
6558 
6559       // Compute the I-WY block representation of a product of Householder
6560       // matrices.
6561       Value w =
6562           ComputeWYRepresentation(op.getLoc(), type.getElementType(),
6563                                   batch_dims, vs, taus, m - i, k, &rewriter);
6564       auto y = vs;
6565 
6566       // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
6567       Value a_panel =
6568           SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter);
6569       auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false,
6570                                batch_dims.size(), precision_config, &rewriter);
6571       a_update = BatchDot(op.getLoc(), y, false, a_update, false,
6572                           batch_dims.size(), precision_config, &rewriter);
6573       a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
6574       a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
6575                                  &rewriter);
6576 
6577       // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
6578       Value q_panel =
6579           SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter);
6580       Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false,
6581                                 batch_dims.size(), precision_config, &rewriter);
6582       q_update = BatchDot(op.getLoc(), q_update, false, y, true,
6583                           batch_dims.size(), precision_config, &rewriter);
6584       q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
6585       q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
6586     }
6587     // full_matrices is false when only a partial result is needed. Slice to the
6588     // needed dimensions here.
6589     if (!op.full_matrices()) {
6590       q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter);
6591       a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter);
6592     }
6593     rewriter.replaceOp(op, {q, a});
6594     return success();
6595   }
6596 
6597  private:
6598   // Computes a Householder reflection of the form:
6599   // H = I - tau v v.T.
6600   // such that
6601   // H . ( x1  ) = ( x1   )
6602   //     ( x2  ) = ( x2   )
6603   //     ( ... ) = ( ...  )
6604   //     ( xk  ) = ( beta )
6605   //     ( ... )   ( 0    )
6606   //     ( ... )   ( 0    )
6607   // Unlike the usual formulation, we allow the caller to supply 'k' rather than
6608   // only providing the relevant part of 'x' to maintain XLA's static shape
6609   // invariant. In addition, the implementation supports batching.
6610   // Pseudo-code, without batching:
6611   //   alpha = x[k]
6612   //   x_copy = np.copy(x)
6613   //   x_copy[:k+1] = 0
6614   //   xnorm = norm2(x_copy)
6615   //   if xnorm == 0:
6616   //     beta = alpha
6617   //     tau = 0
6618   //     v = np.zeros_like(x)
6619   //   else:
6620   //     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
6621   //     tau = (beta - alpha) / beta
6622   //     v = x / (alpha - beta)
6623   //   v[k] = 1
6624   //   return (v, tau, beta)
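  //
  // Worked example (illustrative, k = 0, no batching): for x = [3, 4],
  // alpha = 3 and xnorm = 4, so beta = -sign(3) * sqrt(3^2 + 4^2) = -5,
  // tau = (beta - alpha) / beta = 1.6 and v = [1, 0.5]. Applying the
  // reflection gives
  //   H . x = x - tau * v * (v.T . x) = [3, 4] - 1.6 * 5 * [1, 0.5] = [-5, 0],
  // i.e. (beta, 0) as required.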
6625   void House(Location loc, Value x, Value k, ArrayRef<int64_t> batch_dims,
6626              const int64_t m, OpBuilder *builder, Value *v, Value *tau,
6627              Value *beta) const {
6628     auto x_type = x.getType().cast<RankedTensorType>();
6629 
6630     llvm::SmallVector<int64_t, 4> batch_dim_ids(batch_dims.size());
6631     std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
6632     const int64_t minor_dim = batch_dims.size();
6633 
6634     Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder);
6635     Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder);
6636 
6637     // alpha = x[k]
6638     Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder);
6639     alpha = builder->create<ReshapeOp>(
6640         loc, RankedTensorType::get(batch_dims, x_type.getElementType()), alpha);
6641 
6642     // Compute x[k+1:] (padded with zeros in elements 0..k)
6643     Value iota = builder->create<IotaOp>(
6644         loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
6645         builder->getI64IntegerAttr(0));
6646     Value gtk = builder->create<chlo::BroadcastCompareOp>(
6647         loc, iota, k, GetI64ElementsAttr({}, builder), ComparisonDirection::GT);
6648     gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
6649     Value x_after_k = builder->create<chlo::BroadcastMulOp>(
6650         loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
6651     Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
6652     // sigma = np.dot(x[k+1:], x[k+1:])
6653     auto sigma = builder->create<ReduceOp>(
6654         loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
6655     BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
6656     // mu = np.sqrt(x[k]*x[k] + sigma)
6657     Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
6658     Value mu = builder->create<SqrtOp>(
6659         loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
6660 
6661     Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
6662         loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
6663         ComparisonDirection::EQ);
6664     Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
6665         loc, alpha, zero, GetI64ElementsAttr({}, builder),
6666         ComparisonDirection::LT);
6667     auto batch_size_one = builder->create<BroadcastOp>(
6668         loc, one, GetI64ElementsAttr(batch_dims, builder));
6669     Value signed_mu = builder->create<chlo::BroadcastMulOp>(
6670         loc,
6671         builder->create<SelectOp>(loc, alpha_is_negative, batch_size_one,
6672                                   builder->create<NegOp>(loc, batch_size_one)),
6673         mu, GetI64ElementsAttr({}, builder));
6674     *beta = builder->create<SelectOp>(loc, sigma_is_zero, alpha, signed_mu);
6675     *tau = builder->create<DivOp>(
6676         loc, builder->create<SubtractOp>(loc, *beta, alpha), *beta);
6677     Value zero_tau = builder->create<BroadcastOp>(
6678         loc, zero, GetI64ElementsAttr(batch_dims, builder));
6679     *tau = builder->create<SelectOp>(loc, sigma_is_zero, zero_tau, *tau);
6680     Value divisor = builder->create<SubtractOp>(loc, alpha, *beta);
6681     divisor =
6682         builder->create<SelectOp>(loc, sigma_is_zero, batch_size_one, divisor);
6683 
6684     Value eqk = builder->create<chlo::BroadcastCompareOp>(
6685         loc, iota, k, GetI64ElementsAttr({}, builder), ComparisonDirection::EQ);
6686     eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
6687     llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
6688     e_k_shape.push_back(m);
6689     auto e_k = builder->create<BroadcastOp>(
6690         loc, eqk,
6691         GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
6692                            builder));
6693 
6694     // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
6695     // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
6696     // Note that the add performs a degenerate broadcast.
6697     *v = builder->create<chlo::BroadcastAddOp>(
6698         loc, e_k,
6699         StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
6700                                      GetI64ElementsAttr(batch_dim_ids, builder),
6701                                      *builder),
6702         /*broadcast_dimensions=*/nullptr);
6703   }
6704 
6705   // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
6706   // Loan "Matrix Computations", 4th Edition. This is an unblocked
6707   // implementation used as an inner routine of the blocked implementation.
6708   // Algorithm is adapted slightly so the shapes inside the loop are static, at
6709   // the cost of some redundant computation. Since this is used as an inner
6710   // block kernel, it accumulates the Householder transformations (vs, taus) rather
6711   // than the matrix q. Equivalent Python code, without batching: def qr(a):
6712   //   m = a.shape[0]
6713   //   n = a.shape[1]
6714   //   vs = np.zeros([m, n])
6715   //   taus = np.zeros([n])
6716   //   for j in xrange(min(m, n)):
6717   //     v, tau, beta = house(a[:, j], j)
6718   //     # Unusually, we apply the Householder transformation to the entirety of
6719   //     # a, wasting FLOPs to maintain the static shape invariant that XLA
6720   //     # requires. For columns that precede j this has no effect.
6721   //     a[:, :] -= tau * np.dot(v[:, np.newaxis],
6722   //                              np.dot(v[np.newaxis, :], a[:, :]))
6723   //     # Form column j explicitly rather than relying on the precision of the
6724   //     # Householder update.
6725   //     a[j, j] = beta
6726   //     a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
6727   //     vs[:, j] = v
6728   //     taus[j] = tau
6729   //   return (q, vs, taus)
6730   void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs,
6731                PatternRewriter *rewriter) const {
6732     auto a_type = a.getType().cast<RankedTensorType>();
6733     const int num_dims = a_type.getRank();
6734     assert(num_dims >= 2 && "Argument to QR must have rank >= 2");
6735 
6736     const int64_t m = a_type.getDimSize(a_type.getRank() - 2);
6737     const int64_t n = a_type.getDimSize(a_type.getRank() - 1);
6738 
6739     const int64_t num_batch_dims = num_dims - 2;
6740     auto batch_dims = a_type.getShape().take_front(num_batch_dims);
6741     llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
6742     std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
6743 
6744     auto qr_body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6745                           SmallVectorImpl<Value> *new_values,
6746                           OpBuilder *builder) {
6747       auto a = old_values[0];
6748       auto vs = old_values[1];
6749       auto taus = old_values[2];
6750 
6751       // v, tau, beta = house(a[:, j], j)
6752       auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder);
6753       auto x_collapsed_shape = llvm::to_vector<4>(batch_dims);
6754       x_collapsed_shape.push_back(m);
6755       auto x_collapsed = builder->create<ReshapeOp>(
6756           loc,
6757           RankedTensorType::get(x_collapsed_shape,
6758                                 getElementTypeOrSelf(x.getType())),
6759           x);
6760       Value v, tau, beta;
6761       House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta);
6762 
6763       auto shape = llvm::to_vector<4>(batch_dims);
6764       shape.append({1, m});
6765       auto v_broadcast = builder->create<ReshapeOp>(
6766           loc, RankedTensorType::get(shape, getElementTypeOrSelf(v.getType())),
6767           v);
6768       // a[:, :] -= tau * np.dot(v[:, np.newaxis],
6769       //                          np.dot(v[np.newaxis, :], a[:, :]))
6770       auto precision = builder->getArrayAttr(
6771           {PrecisionAttr::get(builder->getContext(), Precision::HIGHEST),
6772            PrecisionAttr::get(builder->getContext(), Precision::HIGHEST)});
6773       auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims,
6774                           precision, builder);
6775       vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
6776                      precision, builder);
6777       auto tau_x_vva = StaticBinaryBroadcast<mhlo::MulOp>(
6778           loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
6779           *builder);
6780       a = builder->create<SubtractOp>(loc, a, tau_x_vva);
6781 
6782       // It is more precise to populate column 'j' explicitly, rather than
6783       // computing it implicitly by applying the Householder transformation.
6784       // a[j, j] = beta
6785       // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
6786       auto iota = builder->create<IotaOp>(
6787           loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
6788           builder->getI64IntegerAttr(0));
6789       Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
6790           loc, iota, j, GetI64ElementsAttr({}, builder),
6791           ComparisonDirection::LT);
6792       predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
6793                                                     a_type.getElementType());
6794       Value mask = builder->create<chlo::BroadcastCompareOp>(
6795           loc, iota, j, GetI64ElementsAttr({}, builder),
6796           ComparisonDirection::EQ);
6797       mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
6798       mask = builder->create<BroadcastOp>(
6799           loc, mask,
6801           GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
6802                              builder));
6803       Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
6804           loc, x, predecessor_mask,
6805           GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
6806       Value masked_beta = StaticBinaryBroadcast<MulOp>(
6807           loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
6808           *builder);
6809       Value new_x =
6810           builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
6811       // Update a[:,j]
6812       llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
6813       std::iota(dim_ids.begin(), dim_ids.end(), 0);
6814       new_x = builder->create<BroadcastInDimOp>(
6815           loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder));
6816       const int64_t minor_dim = num_batch_dims;
6817       auto iota_mn = builder->create<IotaOp>(
6818           loc,
6819           RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
6820           builder->getI64IntegerAttr(minor_dim + 1));
6821       Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
6822           loc, iota_mn, j, GetI64ElementsAttr({}, builder),
6823           ComparisonDirection::EQ);
6824       a = builder->create<SelectOp>(loc, xa_mask, new_x, a);
6825 
6826       // vs[:, j] = v
6827       llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
6828       std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0);
6829       Value vs_zeros =
6830           GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
6831       vs_zeros = builder->create<BroadcastOp>(
6832           loc, vs_zeros,
6833           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
6834                              builder));
6835       auto vs_update = builder->create<SelectOp>(
6836           loc, xa_mask,
6837           StaticBinaryBroadcast<AddOp>(
6838               loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
6839               *builder),
6840           vs_zeros);
6841       vs = builder->create<AddOp>(loc, vs, vs_update);
6842 
6843       // taus[j] = tau
6844       llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
6845       std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
6846 
6847       auto iota_shape = llvm::to_vector<4>(batch_dims);
6848       iota_shape.push_back(n);
6849       auto iota_n = builder->create<IotaOp>(
6850           loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
6851           builder->getI64IntegerAttr(minor_dim));
6852       Value taus_zeros =
6853           GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
6854       taus_zeros = builder->create<BroadcastOp>(
6855           loc, taus_zeros,
6856           GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
6857                              builder));
6858       Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
6859           loc, iota_n, j, GetI64ElementsAttr({}, builder),
6860           ComparisonDirection::EQ);
6861       auto taus_update = builder->create<SelectOp>(
6862           loc, taus_mask,
6863           StaticBinaryBroadcast<AddOp>(
6864               loc, taus_zeros, tau,
6865               GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
6866           taus_zeros);
6867       taus = builder->create<AddOp>(loc, taus, taus_update);
6868       new_values->assign({a, vs, taus});
6869     };
6870 
6871     Value zero =
6872         GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
6873     *vs = rewriter->create<BroadcastOp>(
6874         loc, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
6875     auto taus_shape = llvm::to_vector<4>(batch_dims);
6876     taus_shape.push_back(n);
6877     *taus = rewriter->create<BroadcastOp>(
6878         loc, zero, GetI64ElementsAttr(taus_shape, rewriter));
6879 
6880     SmallVector<Value, 4> while_output;
6881     CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
6882                   &while_output, rewriter);
6883     *r = while_output[0];
6884     *vs = while_output[1];
6885     *taus = while_output[2];
6886   }
6887 
6888   // Computes W and Y such that I-WY is equivalent to the sequence of
6889   // Householder transformations given by vs and taus.
6891   // Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
6892   // Y = np.zeros([m, n])
6893   // W = np.zeros([m, n])
6894   // Y[:, 0] = vs[:, 0]
6895   // W[:, 0] = -taus[0] * vs[:, 0]
6896   // for j in xrange(1, n):
6897   //   v = vs[:, j]
6898   //   z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
6899   //   W[:, j] = z
6900   //   Y[:, j] = v
6901   // return W
6902   // There is no need to return Y since at termination of the loop it is equal
6903   // to vs.
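  // Note that the explicit initialization of column 0 below is just the loop
  // body specialized to W == 0: in that case z = -taus[0] * vs[:, 0] and the
  // W-dependent term vanishes, so the while loop only needs to cover
  // j in [1, n).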
6904   Value ComputeWYRepresentation(Location loc, Type data_type,
6905                                 ArrayRef<int64_t> batch_dims, Value vs,
6906                                 Value taus, int64_t m, int64_t n,
6907                                 PatternRewriter *rewriter) const {
6908     int64_t n_index = batch_dims.size() + 1;
6909     llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
6910     std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
6911 
6912     auto body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6913                        SmallVectorImpl<Value> *new_values, OpBuilder *builder) {
6914       // w has shape [..., m, n]
6915       auto w = old_values[0];
6916       const auto vs = old_values[1];
6917       const auto taus = old_values[2];
6918 
6919       // Want j values in range [1, ... n).
6920       j = builder->create<AddOp>(
6921           loc, j,
6922           GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
6923                                builder));
6924       // vs has shape [..., m, 1]
6925       auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
6926       // beta has shape [..., 1]
6927       auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder);
6928 
6929       auto iota_shape = llvm::to_vector<4>(batch_dims);
6930       iota_shape.append({m, n});
6931       auto iota_mn = builder->create<IotaOp>(
6932           loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
6933           builder->getI64IntegerAttr(n_index));
6934 
6935       // y has shape [..., m, n]
6936       Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc,
6937                                         0, builder);
6938       zero = builder->create<BroadcastOp>(
6939           loc, zero,
6940           GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
6941                              builder));
6942       auto compare = builder->create<chlo::BroadcastCompareOp>(
6943           loc, iota_mn, j, GetI64ElementsAttr({}, builder),
6944           ComparisonDirection::GE);
6945       auto y = builder->create<SelectOp>(loc, compare, zero, vs);
6946 
6947       // yv has shape [..., n, 1]
6948       auto precision = builder->getArrayAttr(
6949           {PrecisionAttr::get(builder->getContext(), Precision::HIGHEST),
6950            PrecisionAttr::get(builder->getContext(), Precision::HIGHEST)});
6951       auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision,
6952                          builder);
6953       // wyv has shape [..., m, 1]
6954       auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(),
6955                           precision, builder);
6956 
6957       // z = -beta * (v + wyv)
6958       auto neg_beta = builder->create<NegOp>(loc, beta);
6959       auto v_wyv = builder->create<AddOp>(loc, v, wyv);
6960       auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6961       beta_broadcast_dims.push_back(n_index);
6962       auto z = StaticBinaryBroadcast<MulOp>(
6963           loc, neg_beta, v_wyv,
6964           GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);
6965 
6966       w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
6967       new_values->assign({w, vs, taus});
6968     };
6969 
6970     Value w =
6971         GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter);
6972     auto w_shape = llvm::to_vector<4>(batch_dims);
6973     w_shape.append({m, n});
6974     w = rewriter->create<BroadcastOp>(loc, w,
6975                                       GetI64ElementsAttr(w_shape, rewriter));
6976     auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
6977     auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);
6978     auto neg_beta = rewriter->create<NegOp>(loc, beta);
6979     auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6980     beta_broadcast_dims.push_back(n_index);
6981     auto bv = StaticBinaryBroadcast<MulOp>(
6982         loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
6983         *rewriter);
6984     w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);
6985 
6986     SmallVector<Value, 4> while_output;
6987     CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter);
6988     return while_output[0];
6989   }
6990 };
6991 
6992 // Converts tf.XlaConvV2 to mhlo.Conv
6993 class ConvertXlaConvV2Op : public OpRewritePattern<TF::XlaConvV2Op> {
6994  public:
6995   using OpRewritePattern::OpRewritePattern;
6996 
6997   LogicalResult matchAndRewrite(TF::XlaConvV2Op op,
6998                                 PatternRewriter &rewriter) const override {
6999     DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr,
7000         rhs_dilation_attr, feature_group_count_attr;
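    // In tf.XlaConvV2 the window parameters are operands, whereas
    // mhlo.convolution carries them as attributes, so this lowering only
    // applies when they are compile-time constants.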
7001     if (!(matchPattern(op.window_strides(), m_Constant(&window_strides_attr)) &&
7002           matchPattern(op.padding(), m_Constant(&padding_attr)) &&
7003           matchPattern(op.lhs_dilation(), m_Constant(&lhs_dilation_attr)) &&
7004           matchPattern(op.rhs_dilation(), m_Constant(&rhs_dilation_attr)) &&
7005           matchPattern(op.feature_group_count(),
7006                        m_Constant(&feature_group_count_attr))))
7007       return failure();
7008 
7009     auto window_strides_named_attr = rewriter.getNamedAttr(
7010         "window_strides", hlo::convertElementsAttr(window_strides_attr,
7011                                                    rewriter.getIntegerType(64))
7012                               .cast<DenseIntElementsAttr>());
7013 
7014     auto padding_named_attr = rewriter.getNamedAttr(
7015         "padding",
7016         hlo::convertElementsAttr(padding_attr, rewriter.getIntegerType(64))
7017             .cast<DenseIntElementsAttr>());
7018 
7019     auto lhs_dilation_named_attr = rewriter.getNamedAttr(
7020         "lhs_dilation",
7021         hlo::convertElementsAttr(lhs_dilation_attr, rewriter.getIntegerType(64))
7022             .cast<DenseIntElementsAttr>());
7023 
7024     auto rhs_dilation_named_attr = rewriter.getNamedAttr(
7025         "rhs_dilation",
7026         hlo::convertElementsAttr(rhs_dilation_attr, rewriter.getIntegerType(64))
7027             .cast<DenseIntElementsAttr>());
7028 
7029     int64_t feature_group_count_val =
7030         feature_group_count_attr.getValues<IntegerAttr>()[0].getInt();
7031     auto feature_group_count_named_attr = rewriter.getNamedAttr(
7032         "feature_group_count",
7033         rewriter.getI64IntegerAttr(feature_group_count_val));
7034 
7035     auto batch_group_count_named_attr =
7036         rewriter.getNamedAttr("batch_group_count", op.batch_group_countAttr());
7037 
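    // The dimension numbers and precision config are serialized protos stored
    // in string attributes; deserialize them and convert them to the
    // corresponding MHLO attributes.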
7038     xla::ConvolutionDimensionNumbers dnums;
7039     dnums.ParseFromString(op.dimension_numbersAttr().getValue().str());
7040     auto dimension_numbers_named_attr = rewriter.getNamedAttr(
7041         "dimension_numbers",
7042         xla::ConvertConvDimensionNumbers(dnums, &rewriter));
7043 
7044     xla::PrecisionConfig precision_config;
7045     precision_config.ParseFromString(
7046         op.precision_configAttr().getValue().str());
7047     auto precision_config_named_attr = rewriter.getNamedAttr(
7048         "precision_config",
7049         xla::ConvertPrecisionConfig(&precision_config, &rewriter));
7050 
7051     SmallVector<Value, 2> operands{op.lhs(), op.rhs()};
7052     NamedAttribute attrs[] = {
7053         window_strides_named_attr,      padding_named_attr,
7054         lhs_dilation_named_attr,        rhs_dilation_named_attr,
7055         feature_group_count_named_attr, batch_group_count_named_attr,
7056         dimension_numbers_named_attr,   precision_config_named_attr};
7057     rewriter.replaceOpWithNewOp<mhlo::ConvolutionOp>(op, op.getType(), operands,
7058                                                      llvm::makeArrayRef(attrs));
7059     return success();
7060   }
7061 };
7062 
7063 // Converts tf.XlaSelectAndScatter to mhlo.SelectAndScatter
7064 class ConvertXlaSelectAndScatterOp
7065     : public OpRewritePattern<TF::XlaSelectAndScatterOp> {
7066  public:
7067   using OpRewritePattern::OpRewritePattern;
7068 
7069   LogicalResult matchAndRewrite(TF::XlaSelectAndScatterOp op,
7070                                 PatternRewriter &rewriter) const override {
7071     ElementsAttr window_dimensions, window_strides, padding;
7072     if (!(matchPattern(op.window_dimensions(),
7073                        m_Constant(&window_dimensions)) &&
7074           matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
7075           matchPattern(op.padding(), m_Constant(&padding))))
7076       return failure();
7077 
7078     Location loc = op.getLoc();
7079 
7080     SmallVector<Type> result_types{op.getResult().getType()};
7081     // Create the mhlo.SelectAndScatter op.
7082     auto select_and_scatter_op = rewriter.create<mhlo::SelectAndScatterOp>(
7083         loc, result_types, op.operand(), op.source(), op.init_value(),
7084         hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64))
7085             .cast<DenseIntElementsAttr>(),
7086         hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64))
7087             .cast<DenseIntElementsAttr>(),
7088         hlo::convertElementsAttr(padding, rewriter.getIntegerType(64))
7089             .cast<DenseIntElementsAttr>());
7090 
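    // The TF op refers to its select and scatter computations by symbol, while
    // the MHLO op holds them as regions; populate each region with a body that
    // simply calls the referenced function.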
7091     auto insert_call_to = [&](const mlir::SymbolRefAttr &func, Region *region) {
7092       auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
7093           op->getParentOfType<mlir::ModuleOp>(), func));
7094       auto func_ty = func_op.getFunctionType();
7095       BuildBodyWithCall(rewriter, loc, func, func_ty, region);
7096     };
7097 
7098     // Insert a call to the select function in the select region of the mhlo op.
7099     insert_call_to(op.select(), &select_and_scatter_op.select());
7100     // Insert a call to the scatter function in the scatter region of the mhlo
7101     // op.
7102     insert_call_to(op.scatter(), &select_and_scatter_op.scatter());
7103 
7104     rewriter.replaceOp(op, select_and_scatter_op.getResult());
7105 
7106     return success();
7107   }
7108 };
7109 
7110 // Convert tf.XlaSort to mhlo.Sort
7111 class ConvertXlaSortOp : public OpRewritePattern<TF::XlaSortOp> {
7112  public:
7113   using OpRewritePattern::OpRewritePattern;
7114 
7115   LogicalResult matchAndRewrite(TF::XlaSortOp op,
7116                                 PatternRewriter &rewriter) const override {
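    // tf.XlaSort sorts its operand along the minor-most dimension in ascending
    // order, which corresponds to an mhlo.sort over dimension -1 with an LT
    // comparator.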
7117     // Create the sort op.
7118     Type element_type = getElementTypeOrSelf(op.input().getType());
7119     auto sort_op =
7120         createSortOp(&rewriter, op.getLoc(), {op.input()}, {element_type},
7121                      /*dimension=*/-1, /*is_stable=*/false,
7122                      /*direction=*/ComparisonDirection::LT);
7123     rewriter.replaceOp(op, sort_op.getResult(0));
7124     return success();
7125   }
7126 };
7127 
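// Returns the XLA random algorithm corresponding to the given TensorFlow RNG
// algorithm, or llvm::None if there is no XLA equivalent.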
7128 inline llvm::Optional<xla::RandomAlgorithm> TensorFlowRngAlgToXla(
7129     tensorflow::Algorithm alg) {
7130   if (alg == tensorflow::RNG_ALG_PHILOX) {
7131     return xla::RandomAlgorithm::RNG_PHILOX;
7132   } else if (alg == tensorflow::RNG_ALG_THREEFRY) {
7133     return xla::RandomAlgorithm::RNG_THREE_FRY;
7134   } else if (alg == tensorflow::RNG_ALG_AUTO_SELECT) {
7135     return xla::RandomAlgorithm::RNG_DEFAULT;
7136   }
7137   return llvm::None;
7138 }
7139 
7140 // Converts tf.XlaRngBitGenerator op to mhlo.RngBitGenerator op.
7141 class ConvertXlaRngBitGeneratorOp
7142     : public OpRewritePattern<TF::XlaRngBitGeneratorOp> {
7143  public:
7144   using OpRewritePattern::OpRewritePattern;
7145 
7146   LogicalResult matchAndRewrite(TF::XlaRngBitGeneratorOp op,
7147                                 PatternRewriter &rewriter) const override {
7148     Location loc = op.getLoc();
7149     DenseElementsAttr algorithm;
7150     if (!(matchPattern(op.algorithm(), m_Constant(&algorithm))) ||
7151         algorithm.getType().getRank()) {
7152       return op.emitOpError() << "algorithm must be a constant scalar";
7153     }
7154     auto alg = static_cast<tensorflow::Algorithm>(
7155         algorithm.getValues<IntegerAttr>()[0].getInt());
7156     auto xla_alg = TensorFlowRngAlgToXla(alg);
7157     if (!xla_alg) {
7158       return op.emitOpError() << "unknown algorithm";
7159     }
7160 
7161     auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
7162         rewriter.getContext(),
7163         *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.getValue()));
7164     auto rng_bit_generator_op = rewriter.create<mhlo::RngBitGeneratorOp>(
7165         loc, op.getResultTypes(), algorithm_attr, op.initial_state());
7166 
7167     rewriter.replaceOp(op, rng_bit_generator_op.getResults());
7168 
7169     return success();
7170   }
7171 };
7172 
7173 // Converts tf.XlaVariadicReduceV2 to mhlo.Reduce
7174 class ConvertXlaVariadicReduceV2Op
7175     : public OpRewritePattern<TF::XlaVariadicReduceV2Op> {
7176  public:
7177   using OpRewritePattern::OpRewritePattern;
7178 
7179   LogicalResult matchAndRewrite(TF::XlaVariadicReduceV2Op op,
7180                                 PatternRewriter &rewriter) const override {
7181     Location loc = op.getLoc();
7182 
7183     // Create the mhlo.reduce op.
7184     auto reduce_op = rewriter.create<mhlo::ReduceOp>(
7185         loc, op.inputs(), op.init_values(),
7186         GetI64ElementsAttr(op.dimensions_to_reduce()));
7187     mlir::SymbolRefAttr func = op.reducer();
7188     auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
7189         op->getParentOfType<mlir::ModuleOp>(), func));
7190     auto func_ty = func_op.getFunctionType();
7191     // Insert a call to the reducer in the region of the mhlo op.
7192     BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_op.body());
7193 
7194     rewriter.replaceOp(op, reduce_op.getResults());
7195 
7196     return success();
7197   }
7198 };
7199 
7200 // Convert tf.XlaVariadicSort to mhlo.Sort
7201 class ConvertXlaVariadicSortOp
7202     : public OpRewritePattern<TF::XlaVariadicSortOp> {
7203  public:
7204   using OpRewritePattern::OpRewritePattern;
7205 
7206   LogicalResult matchAndRewrite(TF::XlaVariadicSortOp op,
7207                                 PatternRewriter &rewriter) const override {
7208     Location loc = op.getLoc();
7209     ElementsAttr dimension;
7210     matchPattern(op.dimension(), m_Constant(&dimension));
7211     // Create the mhlo.sort op.
7212     auto sort_op = rewriter.create<mhlo::SortOp>(
7213         loc, op.inputs(), dimension.getValues<IntegerAttr>()[0].getInt(),
7214         op.is_stable());
7215     mlir::SymbolRefAttr func = op.comparator();
7216     auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
7217         op->getParentOfType<mlir::ModuleOp>(), func));
7218     auto func_ty = func_op.getFunctionType();
7219     // Insert a call to the comparator in the comparator region of the mhlo op.
7220     BuildBodyWithCall(rewriter, loc, func, func_ty, &sort_op.comparator());
7221 
7222     rewriter.replaceOp(op, sort_op.getResults());
7223     return success();
7224   }
7225 };
7226 
7227 // Convert tf.XlaReducePrecision to mhlo.ReducePrecision
7228 class ConvertXlaReducePrecisionOp
7229     : public OpRewritePattern<TF::XlaReducePrecisionOp> {
7230  public:
7231   using OpRewritePattern::OpRewritePattern;
7232 
7233   LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op,
7234                                 PatternRewriter &rewriter) const override {
7235     IntegerType int32_type = rewriter.getIntegerType(32);
7236     APInt exponent_bits = op.exponent_bitsAttr().getValue();
7237     // Truncating to 32 bits is safe, since passing any number above the dtype
7238     // size (which is at most 64, for float64) is equivalent to passing the
7239     // dtype size.
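    // For example, an exponent_bits value of 1 << 40 saturates to 2^31 - 1,
    // which still exceeds every supported dtype's width, so the lowered op
    // behaves the same as if the full width had been requested.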
7240     IntegerAttr new_exponent_attr =
7241         IntegerAttr::get(int32_type, exponent_bits.truncSSat(32));
7242     APInt mantissa_bits = op.mantissa_bitsAttr().getValue();
7243     IntegerAttr new_mantissa_attr =
7244         IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32));
7245     rewriter.replaceOpWithNewOp<mhlo::ReducePrecisionOp>(
7246         op, op.getType(), op.operand(), new_exponent_attr, new_mantissa_attr);
7247     return success();
7248   }
7249 };
7250 
7251 }  // end namespace
7252 
7253 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
7254 
7255 void PopulateLegalizeTfPatterns(MLIRContext *context,
7256                                 RewritePatternSet *patterns) {
7257   populateWithGenerated(*patterns);
7258   // clang-format off
7259   patterns->add<
7260     ConvertAllOp,
7261     ConvertAnyOp,
7262     ConvertArgMaxOp,
7263     ConvertArgMinOp,
7264     ConvertBatchMatMulV2Op,
7265     ConvertBiasAddOp,
7266     ConvertBroadcastToOp,
7267     ConvertBF16FloorDivOp,
7268     ConvertClipByValueOp,
7269     ConvertConstOp,
7270     ConvertConv2DOp,
7271     ConvertConv3DOp,
7272     ConvertDepthConv2DOp,
7273     ConvertConv2DBackpropFilterOp,
7274     ConvertConv3DBackpropFilterOp,
7275     ConvertConv2DBackpropInputOp,
7276     ConvertConv3DBackpropInputOp,
7277     ConvertCumprodOp,
7278     ConvertCumsumOp,
7279     ConvertDiagPartOp,
7280     ConvertDynamicExpandDimsOp,
7281     ConvertDynamicSqueezeOp,
7282     ConvertEinsumOp,
7283     ConvertRFFTOp,
7284     ConvertIRFFTOp,
7285     ConvertFusedBatchNormGradOp,
7286     ConvertFusedBatchNormGradV2Op,
7287     ConvertFusedBatchNormGradV3Op,
7288     ConvertFusedBatchNormV2Op,
7289     ConvertFusedBatchNormV3Op,
7290     ConvertInfeedDequeueTupleOp,
7291     ConvertIdentityNOp,
7292     ConvertInplaceUpdateOp,
7293     ConvertLinSpaceOp,
7294     ConvertMaxOp,
7295     ConvertMinOp,
7296     ConvertAvgPool2DOp,
7297     ConvertAvgPool3DOp,
7298     ConvertAvgPool2DGradOp,
7299     ConvertAvgPool3DGradOp,
7300     ConvertMaxPool2DOp,
7301     ConvertMaxPool3DOp,
7302     ConvertMaxPool2DGradOp,
7303     ConvertMaxPool3DGradOp,
7304     ConvertMeanOp,
7305     ConvertOneHotOp,
7306     ConvertOutfeedEnqueueTupleOp,
7307     ConvertProdOp,
7308     ConvertQrOp,
7309     ConvertDynamicRangeOp,
7310     ConvertMatrixDiagPartV3Op,
7311     ConvertRangeOp,
7312     ConvertSelectOp,
7313     ConvertSigmoidOp,
7314     ConvertShapeOp,
7315     ConvertSplitOp,
7316     ConvertSplitVOp,
7317     ConvertStridedSliceOp,
7318     ConvertStridedSliceGradOp,
7319     ConvertSumOp,
7320     ConvertTensorScatterAddOp,
7321     ConvertTensorScatterSubOp,
7322     ConvertTensorScatterMinOp,
7323     ConvertTensorScatterMaxOp,
7324     ConvertTensorScatterUpdateOp,
7325     ConvertTileOp,
7326     ConvertTopKV2Op,
7327     ConvertUnpackOp,
7328     ConvertUnsortedSegmentMaxOp,
7329     ConvertUnsortedSegmentMinOp,
7330     ConvertUnsortedSegmentProdOp,
7331     ConvertUnsortedSegmentSumOp,
7332     ConvertRandomShuffleOp,
7333     ConvertXlaShardingOp,
7334     ConvertXlaDynamicUpdateSliceOp,
7335     ConvertXlaConvV2Op,
7336     ConvertXlaReducePrecisionOp,
7337     ConvertXlaReduceScatterOp,
7338     ConvertXlaReduceWindowOp,
7339     ConvertXlaRngBitGeneratorOp,
7340     ConvertXlaSelectAndScatterOp,
7341     ConvertXlaSortOp,
7342     ConvertXlaVariadicReduceV2Op,
7343     ConvertXlaVariadicSortOp,
7344     ConvertRollOp,
7345     ConvertLeakyReluOp,
7346     ConvertLeakyReluGradOp,
7347     ConvertSplitOpDynamic,
7348     ConvertSliceOpDynamic,
7349     ConvertTileOpDynamic,
7350     ConvertUnpackOpDynamic,
7351     ConvertSigmoidGradOpDynamic,
7352     ConvertConv2DDynamic,
7353     ConvertPadOpDynamic,
7354     ConvertGatherNdOpDynamic>(context);
7355   // clang-format on
7356 }
7357 
7358 }  // end namespace mhlo
7359 }  // end namespace mlir
7360