/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the TensorFlow dialect to the XLA
// dialect.

#include <cctype>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/hlo_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace mlir {
namespace mhlo {
namespace {

constexpr char kShardingAttr[] = "mhlo.sharding";

// Returns the feature dimension for the given format and input type.
static size_t GetFeatureDimension(tensorflow::TensorFormat format,
                                  RankedTensorType input_ty) {
  return GetTensorFeatureDimIndex(input_ty.getRank(), format);
}

// Gets all integer values from the given attribute and pushes them to
// `values`.
void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl<int64_t> *values) {
  auto array_attr = attr.cast<ArrayAttr>();
  values->reserve(array_attr.getValue().size());
  for (Attribute val : array_attr.getValue())
    values->push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
}

// Returns a 1D 32-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values,
                                               Builder *builder) {
  RankedTensorType ty = RankedTensorType::get(
      {static_cast<int32_t>(values.size())}, builder->getIntegerType(32));
  return DenseIntElementsAttr::get(ty, values);
}

// Returns a 1-d i64 elements attribute populated with numbers from start to
// end (exclusive).
static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
                                                     Builder *builder) {
  int size = end - start;

  SmallVector<int64_t, 4> vals;
  vals.resize(size);
  std::iota(vals.begin(), vals.end(), start);

  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, vals);
}

// Returns a 1-d i64 elements attribute populated with `val` repeated `size`
// times.
static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val,
                                                       Builder *builder) {
  TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
  return DenseIntElementsAttr::get(ty, val);
}

// Returns the corresponding type that should be used for performing sum
// accumulation over the given input type.
Type GetSumAccumulationType(Type input_type) {
  MLIRContext *ctx = input_type.getContext();
  if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx);
  if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16))
    return IntegerType::get(ctx, 32);
  return input_type;
}

// Returns the axis in HLO format from a TF elements attr with exactly one
// element, or from an IntegerAttr, containing the axis in TensorFlow format.
// TensorFlow format supports negative indexing unlike HLO.
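// For example, with rank = 4, a TensorFlow axis of -1 maps to HLO axis 3.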
static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank,
                                        Builder *b) {
  IntegerAttr intAttr = attr.dyn_cast_or_null<IntegerAttr>();
  if (auto elementAttr = attr.dyn_cast_or_null<ElementsAttr>()) {
    SmallVector<uint64_t, 1> index(elementAttr.getType().getRank(), 0);
    intAttr = elementAttr.getValues<IntegerAttr>()[index];
  }

  assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis");

  int64_t axis = intAttr.getInt();
  if (axis < 0) {
    axis += rank;
  }
  return b->getI64IntegerAttr(axis);
}

// If `value` is an IntegerAttr, returns the integer value for the HLO axis
// corresponding to the tensorflow axis. In particular, the tensorflow axis can
// be negative, in which case, the corresponding HLO axis is
// (axis + rank-of-the-tensor).
static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value value,
                                                           int64_t rank) {
  DenseIntElementsAttr attrs;
  if (!matchPattern(value, m_Constant(&attrs)) ||
      attrs.getType().getRank() != 0) {
    return llvm::None;
  }
  int64_t axis = attrs.getValues<IntegerAttr>()[0].getInt();
  return axis < 0 ? axis + rank : axis;
}

// Returns a `ConvertOp` that casts the elements to an i64 type while retaining
// the shape of the input value.
static ConvertOp CastValueToI64(Location loc, Value value,
                                PatternRewriter *rewriter) {
  return rewriter->create<ConvertOp>(loc, value, rewriter->getIntegerType(64));
}

// Creates an unpack op along the 0th dimension of the tensor. The `value`
// input must be a ranked tensor.
static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value,
                                             PatternRewriter *rewriter) {
  auto indices_type = value.getType().cast<RankedTensorType>();
  int num_outputs = indices_type.getShape().front();
  SmallVector<Type, 2> unpacked_indices_type(
      num_outputs, RankedTensorType::get({}, indices_type.getElementType()));
  auto unpacked_indices = rewriter->create<TF::UnpackOp>(
      loc, unpacked_indices_type, value,
      IntegerAttr::get(rewriter->getIntegerType(64), 0));
  return unpacked_indices;
}

// Returns the size of the dimension at the specified index if the type is a
// ranked tensor. Otherwise, returns -1.
//
// Aborts if the type is ranked but doesn't have the dimension.
int64_t GetDimSize(Type ty, int64_t index) {
  RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
  if (!ranked_ty) return -1;

  return ranked_ty.getDimSize(index);
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(llvm::ArrayRef<T> sizes) {
  return tensorflow::TensorShape(
      llvm::SmallVector<int64_t, num_dims>(sizes.begin(), sizes.end()));
}

template <typename T, int num_dims>
tensorflow::TensorShape ToTensorShape(
    llvm::iterator_range<DenseElementsAttr::ElementIterator<T>> sizes) {
  return tensorflow::TensorShape(
      llvm::SmallVector<int64_t, num_dims>(sizes.begin(), sizes.end()));
}

// Returns a limit scalar const op for the given type.
// Requires FloatType or IntegerType.
static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc,
                                            hlo::ScalarLimit limit,
                                            OpBuilder *builder) {
  return builder->create<ConstantOp>(loc, hlo::getScalarLimitOfType(ty, limit));
}

// Creates an mhlo::SliceOp where the major dimensions have full size, and
// the minor dimensions have the provided offsets and sizes.
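// For example, for a value of shape [8, 10, 12] with minor_starts = {2, 3}
// and minor_limits = {5, 7}, the resulting slice covers [0:8, 2:5, 3:7].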
static Value SliceInMinorDims(Location loc, Value v,
                              ArrayRef<int64_t> minor_starts,
                              ArrayRef<int64_t> minor_limits,
                              OpBuilder *builder) {
  auto type = v.getType().cast<RankedTensorType>();
  llvm::SmallVector<int64_t, 4> slice_starts(type.getRank(), 0);
  int64_t major_dims = type.getRank() - minor_starts.size();
  std::copy(minor_starts.begin(), minor_starts.end(),
            slice_starts.begin() + major_dims);
  auto slice_limits = llvm::to_vector<4>(type.getShape());
  std::copy(minor_limits.begin(), minor_limits.end(),
            slice_limits.begin() + major_dims);
  llvm::SmallVector<int64_t, 4> slice_strides(type.getRank(), 1);
  return builder->create<SliceOp>(loc, v,
                                  GetI64ElementsAttr(slice_starts, builder),
                                  GetI64ElementsAttr(slice_limits, builder),
                                  GetI64ElementsAttr(slice_strides, builder));
}

// Creates a vector of index values:
//  [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]]
// with length `rank`.
static llvm::SmallVector<Value, 4> CreateFullIndexVectorFromMinorIndices(
    Location loc, ArrayRef<Value> minor_indices, int64_t rank,
    OpBuilder *builder) {
  auto zero =
      GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()),
                           loc, 0, builder)
          .output();
  llvm::SmallVector<Value, 4> indices(rank, zero);
  std::copy(minor_indices.begin(), minor_indices.end(),
            indices.begin() + (rank - minor_indices.size()));
  return indices;
}

// Creates an mhlo::DynamicSliceOp where the major dimensions have full size,
// and the minor dimensions have the provided offsets and sizes.
static Value DynamicSliceInMinorDims(Location loc, Value v,
                                     ArrayRef<Value> minor_starts,
                                     ArrayRef<int64_t> minor_sizes,
                                     OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto slice_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  int64_t major_dims = type.getRank() - minor_starts.size();
  auto slice_sizes = llvm::to_vector<4>(type.getShape());
  std::copy(minor_sizes.begin(), minor_sizes.end(),
            slice_sizes.begin() + major_dims);
  return builder->create<mhlo::DynamicSliceOp>(
      loc, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided offsets.
static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
                                           ArrayRef<Value> minor_starts,
                                           OpBuilder *builder) {
  if (minor_starts.empty()) return v;
  auto type = v.getType().cast<RankedTensorType>();
  auto dus_starts = CreateFullIndexVectorFromMinorIndices(
      loc, minor_starts, type.getRank(), builder);
  return builder->create<DynamicUpdateSliceOp>(loc, type, v, update,
                                               llvm::makeArrayRef(dus_starts));
}

// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided static offsets.
static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
                                    ArrayRef<int64_t> minor_starts,
                                    OpBuilder *builder) {
  llvm::SmallVector<Value, 4> dus_starts(minor_starts.size());
  for (uint64_t i = 0; i < minor_starts.size(); ++i) {
    dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
                                         minor_starts[i], builder);
  }
  return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Gets the resulting type from a broadcast between two types for statically
// shaped types. This is to be used for legacy lowerings that both use
// non-left-padded broadcasting and static shapes. Its use should not be
// permitted in new code.
// May return nullptr on invalid static broadcast dimensions.
// ABSL_DEPRECATED()
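// For example, broadcasting a [2, 1, 4] type with a [3, 4] type and
// broadcast_dimensions = [1, 2] yields a [2, 3, 4] type.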
static RankedTensorType GetStaticBroadcastType(
    RankedTensorType x, RankedTensorType y,
    DenseIntElementsAttr broadcast_dimensions_attr) {
  auto element_type = x.getElementType();
  auto shape_x = x.getShape();
  auto shape_y = y.getShape();

  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (int i = 0; i < shape_x.size(); i++) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      out_shape[i] = std::max(x_val, y_val);
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  // Explicit broadcast dimensions.
  for (const APInt &int_value : broadcast_dimensions_attr) {
    broadcast_dimensions.push_back(int_value.getSExtValue());
  }
  if (broadcast_dimensions.size() != shape_small.size()) {
    return nullptr;
  }
  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    out_shape[index_pair.value()] = std::max(old_value, new_value);
  }
  return RankedTensorType::get(out_shape, element_type);
}

// Deprecated: This is maintained to aid in porting old code that is not yet
// dynamic shape aware and uses broadcasting modes that CHLO does not support.
// Applies static binary broadcasting to a binary elementwise op.
// This is a legacy helper to provide general broadcasting support in legacy,
// static shaped code that relies on non-left-padded broadcasting semantics.
template <typename BinaryOp>
static Value StaticBinaryBroadcast(Location loc, Value x, Value y,
                                   DenseIntElementsAttr broadcast_dims,
                                   OpBuilder &builder) {
  auto x_type = x.getType().cast<RankedTensorType>();
  auto y_type = y.getType().cast<RankedTensorType>();
  auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims);
  if (!result_type) {
    emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type
                   << " with broadcast_dims = " << broadcast_dims;
    return nullptr;
  }
  auto larger_broadcast_dims =
      GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder);
  if (x_type.getRank() < y_type.getRank()) {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x, broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y,
                                           larger_broadcast_dims);
    }
  } else {
    if (x_type != result_type) {
      x = builder.create<BroadcastInDimOp>(loc, result_type, x,
                                           larger_broadcast_dims);
    }
    if (y_type != result_type) {
      y = builder.create<BroadcastInDimOp>(loc, result_type, y, broadcast_dims);
    }
  }
  return builder.create<BinaryOp>(loc, x, y);
}

// Gets a 1D tensor type suitable for expressing extents of the given tensor
// value type. If the value type is ranked, the result will be statically
// shaped. Otherwise, it will have a dynamic dimension.
static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) {
  Builder b(value_type.getContext());
  int64_t dim = value_type.hasRank() ? value_type.getRank() : -1;
  return RankedTensorType::get({dim}, b.getIndexType());
}

// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D
// value (broadcast_from) along that feature dimension. This is a shortcut
// for the cases where a 1D tensor must be broadcast along a specific feature
// dimension, which can vary based on data layout, etc.
//
// The extent of `broadcast_from` dim0 must be equal to the extent of the
// feature_dim of `broadcast_to`.
//
// Example:
//   [1x2x3x4], [2], 1 -> [1x2x3x4]
// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for
// consistency. Possibly also rename for clarity.
static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to,
                                     Value broadcast_from, int64_t feature_dim,
                                     OpBuilder &builder) {
  auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder);
  auto to_type = broadcast_to.getType().cast<RankedTensorType>();
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, broadcast_from, result_extents, broadcast_dims);
}

// Broadcasts `input` to the shape of `broadcast_to` value following
// TF::BroadcastTo semantics.
//
// Requires that input is a ranked tensor.
//
// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp
// supports unranked inputs in the lowering.
static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to,
                                OpBuilder &builder) {
  auto result_shape = builder.create<shape::ShapeOfOp>(loc, broadcast_to);
  auto to_type = broadcast_to.getType().cast<TensorType>();
  auto result_extents_type = GetExtentsTensorTypeFor(to_type);
  auto result_extents = builder.create<shape::ToExtentTensorOp>(
      loc, result_extents_type, result_shape);
  int64_t rank = input.getType().cast<RankedTensorType>().getRank();
  auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder);
  return builder.create<DynamicBroadcastInDimOp>(
      loc, to_type, input, result_extents, broadcast_dims);
}

// Creates a batch dot using mhlo::DotGeneralOp.
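// For example, with num_batch_dims = 1 and no transposes, an lhs of shape
// [B, M, K] and an rhs of shape [B, K, N] produce a result of shape [B, M, N].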
Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
               bool transpose_rhs, int64_t num_batch_dims,
               ArrayAttr precision_config, OpBuilder *builder) {
  auto batch_dimensions =
      llvm::to_vector<4>(llvm::seq<int64_t>(0, num_batch_dims));
  auto lhs_contracting_dimensions = llvm::to_vector<1>(llvm::makeArrayRef(
      {transpose_lhs ? num_batch_dims : num_batch_dims + 1}));
  auto rhs_contracting_dimensions = llvm::to_vector<1>(llvm::makeArrayRef(
      {transpose_rhs ? num_batch_dims + 1 : num_batch_dims}));
  auto dimension_numbers = DotDimensionNumbersAttr::get(
      builder->getContext(),
      /*lhs_batching_dimensions=*/batch_dimensions,
      /*rhs_batching_dimensions=*/batch_dimensions,
      /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
      /*rhs_contracting_dimensions=*/rhs_contracting_dimensions);
  auto lhs_shape = lhs.getType().cast<RankedTensorType>().getShape();
  auto rhs_shape = rhs.getType().cast<RankedTensorType>().getShape();
  auto shape = llvm::to_vector<4>(lhs_shape);
  shape[shape.size() - 2] =
      transpose_lhs ? lhs_shape.back() : lhs_shape[lhs_shape.size() - 2];
  shape[shape.size() - 1] =
      transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back();
  Type element_type = getElementTypeOrSelf(lhs.getType());
  return builder->create<DotGeneralOp>(
      loc, RankedTensorType::get(shape, element_type), lhs, rhs,
      dimension_numbers, precision_config);
}

// Builds a set of operations for applying reduction on the input value. A
// tf.sum op is created and will be legalized to mhlo ops automatically.
static Value ApplyReduction(Location loc, Value input,
                            DenseIntElementsAttr reduce_dims,
                            OpBuilder *builder) {
  auto reduce_dims_op = builder->create<ConstantOp>(loc, reduce_dims);
  return builder->create<TF::SumOp>(loc, input, reduce_dims_op,
                                    builder->getBoolAttr(false));
}

// Creates an mhlo.rng op (uniform distribution) with `builder` to generate
// `num_elements` 32-bit integer numbers in the range
// [`lower_limit`, `upper_limit`).
static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements,
                                      int lower_limit, int upper_limit,
                                      OpBuilder *builder) {
  auto shape_tensor = builder->create<mhlo::ConstantOp>(
      loc, GetI64ElementsAttr({num_elements}, builder));

  auto lower = builder->create<mhlo::ConstantOp>(
      loc, builder->getI32IntegerAttr(lower_limit));
  auto upper = builder->create<mhlo::ConstantOp>(
      loc, builder->getI32IntegerAttr(upper_limit));

  return builder->create<mhlo::RngOp>(loc, lower, upper, shape_tensor,
                                      ::mlir::mhlo::RngDistribution::UNIFORM);
}

using WhileBodyFnType = llvm::function_ref<void(
    Location loc, Value iteration, ArrayRef<Value> old_values,
    SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;

// Creates an mhlo.while op with `builder` to loop `num_iterations` times,
// each time calling the given `body_fn` on a set of values to generate a new
// set of values. Returns the final set of values via `final_values`. The
// initial set of values is passed in via `init_values`.
//
// This effectively does:
//
// ```c++
// SmallVector<Value, 4> old_values = init_values;
// SmallVector<Value, 4> new_values;
// for (int i = 0; i < num_iterations; ++i) {
//   body_fn(old_values, &new_values, ...);
//   old_values = new_values;
// }
// ```
//
// Under the hood an induction variable is prepended to values to control the
// number of iterations, but that is transparent to `body_fn`, which does not
// need to care about that.
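//
// A minimal usage sketch (illustrative only; `loc`, `builder` and
// `init_values` are assumed to be in scope):
//
// ```c++
// SmallVector<Value, 4> results;
// CreateWhile32(loc, /*num_iterations=*/10,
//               [](Location loc, Value iteration, ArrayRef<Value> old_values,
//                  SmallVectorImpl<Value> *new_values, OpBuilder *builder) {
//                 // Forward the values unchanged in this sketch.
//                 new_values->assign(old_values.begin(), old_values.end());
//               },
//               init_values, &results, builder);
// ```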
static void CreateWhile32(Location loc, int num_iterations,
                          WhileBodyFnType body_fn, ArrayRef<Value> init_values,
                          SmallVectorImpl<Value> *final_values,
                          OpBuilder *builder) {
  int value_count = init_values.size() + 1;

  // Prepend a loop induction variable to the initial values.
  SmallVector<Value, 2> init_values_with_loop_iv;
  SmallVector<Type, 2> init_types_with_loop_iv;
  init_values_with_loop_iv.reserve(value_count);
  init_types_with_loop_iv.reserve(value_count);

  // The initial value for the loop induction variable is 0.
  init_values_with_loop_iv.push_back(
      builder->create<mhlo::ConstantOp>(loc, builder->getI32IntegerAttr(0)));
  init_values_with_loop_iv.append(init_values.begin(), init_values.end());

  // Accumulate types of all the init values.
  for (const auto &init_value_with_loop_iv : init_values_with_loop_iv)
    init_types_with_loop_iv.push_back(init_value_with_loop_iv.getType());

  // Create the while op.
  auto while_op = builder->create<mhlo::WhileOp>(loc, init_types_with_loop_iv,
                                                 init_values_with_loop_iv);
  auto ivs_count = init_types_with_loop_iv.size();

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the condition region.
    Region &condition = while_op.cond();
    Block *block = builder->createBlock(&condition);
    block->addArguments(init_types_with_loop_iv,
                        SmallVector<Location>(ivs_count, loc));

    // Get the loop induction variable and compare it against the upper limit.
    auto loop_iv = block->getArgument(0);
    auto upper_limit = builder->create<mhlo::ConstantOp>(
        loc, builder->getI32IntegerAttr(num_iterations));
    Value compare = builder->create<mhlo::CompareOp>(loc, loop_iv, upper_limit,
                                                     ComparisonDirection::LT);

    builder->create<mhlo::ReturnOp>(loc, compare);
  }

  {
    OpBuilder::InsertionGuard guard(*builder);

    // Build up the only block in the body region.
    Region &body = while_op.body();
    Block *block = builder->createBlock(&body);
    block->addArguments(init_types_with_loop_iv,
                        SmallVector<Location>(ivs_count, loc));

    SmallVector<Value, 4> new_values;  // Generated by this iteration
    new_values.reserve(value_count);

    // Feed all values excluding the loop induction variable to body_fn.
    body_fn(loc, block->getArgument(0),
            ArrayRef<Value>(block->getArguments().begin() + 1,
                            block->getArguments().end()),
            &new_values, builder);

    // Increment the loop induction variable by one.
    auto one =
        builder->create<mhlo::ConstantOp>(loc, builder->getI32IntegerAttr(1));
    auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
    auto plus_one = builder->create<chlo::BroadcastAddOp>(
        loc, block->getArgument(0), one, scalar_broadcast_dims);
    // Prepend with the updated loop induction variable.
    new_values.insert(new_values.begin(), plus_one);

    builder->create<mhlo::ReturnOp>(loc, new_values);
  }

  // TODO(jpienaar): Support multi-operand while op.
  final_values->reserve(init_values.size());
  for (int i = 0, e = init_values.size(); i < e; ++i)
    final_values->push_back(while_op.getResult(i + 1));
}

//===----------------------------------------------------------------------===//
// BatchNorm op utilities.
//===----------------------------------------------------------------------===//

static IntegerAttr getFeatureDimensionAttr(Builder &b,
                                           tensorflow::TensorFormat format,
                                           Value input) {
  return b.getI64IntegerAttr(
      GetFeatureDimension(format, input.getType().cast<RankedTensorType>()));
}

//===----------------------------------------------------------------------===//
// FFT op utilities.
//===----------------------------------------------------------------------===//

// Returns the 1D i64 elements attribute populated with the inner-most dim of
// the value.
static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type,
                                                 Builder *builder) {
  if (type.getRank() == 0) {
    return builder->getI64TensorAttr({});
  }
  return builder->getI64TensorAttr(type.getShape().back());
}

// Returns true if the inner-most dim is static.
bool CheckInnerDimStatic(ShapedType type, Builder *builder) {
  if (!type.hasRank()) {
    return false;
  }
  return !type.isDynamicDim(type.getShape().size() - 1);
}

//===----------------------------------------------------------------------===//
// MatMul op utilities.
//===----------------------------------------------------------------------===//

// If the 'transpose' attribute is true, returns an ElementsAttr to transpose a
// 2D matrix. Otherwise, returns an ElementsAttr for the identity transpose.
static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) {
  if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b);
  return GetI64ElementsAttr({0, 1}, b);
}

//===----------------------------------------------------------------------===//
// MatrixBandPart op utilities.
//===----------------------------------------------------------------------===//

// Gets the size of the dimension `dim_from_end` from the end of `input`.
// Requires that `input` is a tensor.
static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
  // Note: the verifier enforces that `input` is a ranked tensor.
  auto input_type = input.getType().cast<TensorType>();
  auto input_shape = input_type.getShape();
  int dim = (input_shape.size() - 1) - dim_from_end;
  return input_shape[dim];
}

// Gets a 2D tensor type with shape {dim_0, dim_1}, where `dim_0` and `dim_1`
// have the same size as the last two dimensions of `input` (the second-to-last
// dimension and last dimension, respectively). The element type of the
// resulting RankedTensorType matches the element type of `num_lower`.
// Requires that `input` is a tensor.
static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
  // `dim_0` refers to the second-to-last dimension; `dim_1` refers to the
  // last.
  int dim_0 = GetDimensionSizeFromEnd(input, 1);
  int dim_1 = GetDimensionSizeFromEnd(input, 0);
  auto element_type = num_lower.getType().cast<TensorType>().getElementType();
  return RankedTensorType::get({dim_0, dim_1}, element_type);
}

// Creates a HLO ConvertOp, converting `input` to have the same element type as
// `elem_type_tensor`. Requires `elem_type_tensor` to be a tensor.
static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input,
                             Value elem_type_tensor) {
  auto element_type =
      elem_type_tensor.getType().cast<TensorType>().getElementType();
  return builder->create<mhlo::ConvertOp>(loc, input, element_type);
}

//===----------------------------------------------------------------------===//
// Pad op utilities.
//===----------------------------------------------------------------------===//

// Slices an input attribute of rank two and returns the specified column.
//
// Always returns a 64-bit integer attribute regardless of the bitwidth of the
// input attribute.
static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
    ElementsAttr input, int column) {
  auto int_attr = input.cast<DenseIntElementsAttr>();
  auto shaped_type = int_attr.getType();
  auto shape = shaped_type.getShape();

  if (shape.size() != 2) return DenseIntElementsAttr();

  llvm::SmallVector<int64_t, 4> values;
  values.reserve(shaped_type.getNumElements() / shape[1]);

  for (auto it : llvm::enumerate(int_attr.getValues<APInt>())) {
    if (static_cast<int>(it.index() % shape[1]) == column) {
      values.push_back(it.value().getSExtValue());
    }
  }

  auto element_type = IntegerType::get(input.getContext(), 64);
  return DenseIntElementsAttr::get(
      RankedTensorType::get({shape[0]}, element_type), values);
}

// Returns the interior padding to use in an HLO Pad op based on the padding
// attribute of the TensorFlow PadV2 op.
static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
  auto length = tf_padding.getType().getShape()[0];
  auto element_type = IntegerType::get(tf_padding.getContext(), 64);
  return DenseIntElementsAttr::get<int64_t>(
      RankedTensorType::get({length}, element_type), 0);
}

//===----------------------------------------------------------------------===//
// Binary op utilities.
//===----------------------------------------------------------------------===//

// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; size-1 dimensions broadcast up to any rank. A dynamic dimension
// may only be broadcast against a size-1 dimension or another dynamic
// dimension. Returns false if either value is unranked.
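// For example, shapes [?, 4] and [1, 4] are compatible, while [?, 4] and
// [3, 4] are not (a dynamic dimension paired with a static size other than 1).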
static bool AreBroadcastCompatible(Value x, Value y) {
  auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
  auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
  if (!x_rankless || !y_rankless) {
    return false;
  }

  // Check that the shapes can be broadcasted.
  auto shape_x = x_rankless.getShape();
  auto shape_y = y_rankless.getShape();

  int rank_diff = shape_x.size() - shape_y.size();
  int offset_x = rank_diff > 0 ? rank_diff : 0;
  int offset_y = rank_diff < 0 ? -rank_diff : 0;
  for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) {
    int index_x = i + offset_x;
    int index_y = i + offset_y;
    if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) ||
        (shape_y[index_y] == -1 && shape_x[index_x] != 1)) {
      return false;
    }
  }

  return true;
}

// Returns a new TensorType with the same rank and dimensions as the input but
// with an updated element type.
static Type ChangeTensorElementType(Builder *b, Type tensor_type,
                                    Type element_type) {
  RankedTensorType ranked_type = tensor_type.dyn_cast<RankedTensorType>();
  if (ranked_type) {
    return RankedTensorType::get(ranked_type.getShape(), element_type);
  }

  return UnrankedTensorType::get(element_type);
}

//===----------------------------------------------------------------------===//
// Softmax op utilities.
//===----------------------------------------------------------------------===//

// Returns the type to use for accumulating the given type.
static Type GetAccumulationType(Type ty) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}

//===----------------------------------------------------------------------===//
// Softplus op utilities.
//===----------------------------------------------------------------------===//

static DenseElementsAttr GetEpsilonValue(Type ty) {
  auto element_ty = ty.cast<TensorType>().getElementType();
  auto scalar_ty = RankedTensorType::get({}, element_ty);
  if (element_ty.isF16()) {
    uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
        Eigen::NumTraits<Eigen::half>::epsilon());
    auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isBF16()) {
    uint16_t raw_epsilon = Eigen::numext::bit_cast<uint16_t>(
        Eigen::NumTraits<Eigen::bfloat16>::epsilon());
    auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF32()) {
    auto value = APFloat(std::numeric_limits<float>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF64()) {
    auto value = APFloat(std::numeric_limits<double>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  }
  llvm_unreachable("unsupported element type for tf.SoftPlus");
}

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//

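// Builds the reduction body used when lowering tf.ArgMax/tf.ArgMin: it keeps
// the value that wins the `direction` comparison together with its index, and
// on ties selects the smaller index.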
static void BuildArgMinMaxReductionBody(Type input_element_type,
                                        Type index_element_type,
                                        ComparisonDirection direction,
                                        Region *body, OpBuilder *builder) {
  OpBuilder::InsertionGuard insertion_point_guard(*builder);

  Type input_type = RankedTensorType::get(/*shape=*/{}, input_element_type);
  Type index_type = RankedTensorType::get(/*shape=*/{}, index_element_type);
  Block *block = builder->createBlock(body);
  Location loc = body->getLoc();
  block->addArguments({input_type, index_type, input_type, index_type},
                      SmallVector<Location, 4>(4, loc));

  Value lhs_val = block->getArgument(0);
  Value lhs_index = block->getArgument(1);
  Value rhs_val = block->getArgument(2);
  Value rhs_index = block->getArgument(3);

  ImplicitLocOpBuilder b(loc, *builder);
  Value compare_dt = b.create<CompareOp>(lhs_val, rhs_val, direction);
  Value selected_input =
      b.create<SelectOp>(input_type, compare_dt, lhs_val, rhs_val);

  Value compare_eq =
      b.create<CompareOp>(lhs_val, rhs_val, ComparisonDirection::EQ);
  Value min_index = b.create<MinOp>(lhs_index, rhs_index);
  Value min_val_index =
      b.create<SelectOp>(index_type, compare_dt, lhs_index, rhs_index);
  Value selected_index =
      b.create<SelectOp>(index_type, compare_eq, min_index, min_val_index);

  Value return_values[] = {selected_input, selected_index};
  b.create<ReturnOp>(return_values);
}

//===----------------------------------------------------------------------===//
// PartitionedCall op utilities.
//===----------------------------------------------------------------------===//

// Verifies that the arguments to be passed into the function are the same
// types as the function parameter types.
static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args,
                                SymbolRefAttr func) {
  auto module = op->getParentOfType<ModuleOp>();
  auto function =
      dyn_cast_or_null<func::FuncOp>(SymbolTable::lookupSymbolIn(module, func));
  FunctionType function_ty = function.getFunctionType();

  for (auto arg_in : llvm::zip(args, function_ty.getInputs())) {
    if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) {
      // Argument type and input type mismatch.
      return false;
    }
  }
  return true;
}

//===----------------------------------------------------------------------===//
// Slice op utilities.
//===----------------------------------------------------------------------===//

static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
                                          DenseIntElementsAttr slice_sizes) {
  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  if (!input_ty) return false;
  auto start_indices_ty = start_indices.getType().dyn_cast<RankedTensorType>();
  if (!start_indices_ty) return false;

  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  DenseIntElementsAttr constant_start_indices;
  bool is_constant_start =
      matchPattern(start_indices, m_Constant(&constant_start_indices));

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t slice_size = slice_sizes.getValues<IntegerAttr>()[i].getInt();
    // A slice_size of -1 means "all elements from start_index to the end".
    // In order to support these semantics, we need to know both the start
    // index and the size of the input dimension.
    if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false;
  }
  return true;
}

// TF slice size can be -1, which represents all elements from start_index to
// the end. HLO slice size can't be -1. As such, we need to translate TF slice
// size -1 to HLO slice size.
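// For example, for an input dimension of size 10 and a constant start index
// of 3, a TF slice size of -1 becomes an HLO slice size of 7.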
static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
    Value input, Value start_indices, DenseIntElementsAttr slice_sizes,
    Builder *builder) {
  DenseIntElementsAttr constant_start_indices;
  if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) {
    return hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))
        .cast<DenseIntElementsAttr>();
  }

  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
  int64_t input_rank = input_ty.getRank();
  ArrayRef<int64_t> input_shape = input_ty.getShape();
  SmallVector<int64_t, 4> normalized_sizes;

  for (int64_t i = 0; i < input_rank; ++i) {
    int64_t input_size = input_shape[i];
    int64_t start_index =
        constant_start_indices.getValues<IntegerAttr>()[i].getInt();
    int64_t slice_size = slice_sizes.getValues<IntegerAttr>()[i].getInt();
    normalized_sizes.push_back(slice_size == -1 ? input_size - start_index
                                                : slice_size);
  }

  return GetI64ElementsAttr(normalized_sizes, builder);
}

//===----------------------------------------------------------------------===//
// XlaGather op utilities.
//===----------------------------------------------------------------------===//

bool HasValidGatherDims(StringAttr attr) {
  ::xla::GatherDimensionNumbers dims;
  return dims.ParseFromString(attr.getValue().str());
}

GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr,
                                                Builder *builder) {
  ::xla::GatherDimensionNumbers dims;
  if (!dims.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertGatherDimensionNumbers(dims, builder);
}

//===----------------------------------------------------------------------===//
// XlaDot op utilities.
//===----------------------------------------------------------------------===//

bool HasValidDotDims(StringAttr attr) {
  ::xla::DotDimensionNumbers dims;
  return dims.ParseFromString(attr.getValue().str());
}

DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, Builder *builder) {
  ::xla::DotDimensionNumbers dims;
  if (!dims.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertDotDimensionNumbers(dims, builder);
}

bool HasValidPrecisionConfig(StringAttr attr) {
  ::xla::PrecisionConfig precision;
  return precision.ParseFromString(attr.getValue().str());
}

mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) {
  ::xla::PrecisionConfig precision;
  if (!precision.ParseFromString(attr.getValue().str())) return {};
  return ::xla::ConvertPrecisionConfig(&precision, builder);
}

//===----------------------------------------------------------------------===//
// XlaVariadicReduceV2 op utilities.
//===----------------------------------------------------------------------===//

static void BuildBodyWithCall(PatternRewriter &rewriter, const Location &loc,
                              mlir::SymbolRefAttr func,
                              mlir::FunctionType func_ty, Region *body) {
  OpBuilder::InsertionGuard guard(rewriter);

  Block *block = rewriter.createBlock(body);
  auto inputs = func_ty.getInputs();
  block->addArguments(inputs, SmallVector<Location>(inputs.size(), loc));
  mlir::func::CallOp call_op = rewriter.create<mlir::func::CallOp>(
      loc, func, func_ty.getResults(), block->getArguments());
  rewriter.create<mhlo::ReturnOp>(loc, call_op.getResults());
}

//===----------------------------------------------------------------------===//
// Op converters.
//===----------------------------------------------------------------------===//

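// Returns the convolution dimension numbers attribute for the given spatial
// dimensions and data format. The input and output follow `format` (e.g.
// NHWC), while the filter layout is always HWIO, so the kernel spatial
// dimensions are [0, num_spatial_dims).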
NamedAttribute GetConvDimensionNumbersAttr(ArrayRef<int64_t> spatial_dims,
                                           tensorflow::TensorFormat format,
                                           Builder *builder) {
  int64_t num_spatial_dims = spatial_dims.size();
  int64_t num_dims = num_spatial_dims + 2;

  int64_t batch_dim = GetTensorBatchDimIndex(num_dims, format);
  int64_t feature_dim = GetTensorFeatureDimIndex(num_dims, format);

  // The filter data_format is always HWIO, so the input channels dimension
  // comes after all spatial dimensions.
  int64_t kernel_input_feature_dim = num_spatial_dims;
  int64_t kernel_output_feature_dim = num_spatial_dims + 1;
  SmallVector<int64_t, 4> kernel_spatial_dimensions;
  kernel_spatial_dimensions.resize(num_spatial_dims);
  std::iota(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(),
            0);

  return builder->getNamedAttr(
      "dimension_numbers",
      ConvDimensionNumbersAttr::get(
          builder->getContext(), batch_dim, feature_dim, spatial_dims,
          kernel_input_feature_dim, kernel_output_feature_dim,
          kernel_spatial_dimensions, batch_dim, feature_dim, spatial_dims));
}

// Converts a TF::BiasAddOp to HLO.
// This differs from a normal TF::AddOp with respect to how the data_format
// is handled, which can optionally require a general broadcast of the
// 'bias' term in a way that is not compatible with the standard left-padded
// broadcast semantics (i.e. NCHW will broadcast into dimension 1).
// The correct 'bias' broadcast will be synthesized manually.
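// For example, with NCHW data format, a bias of shape [C] is broadcast along
// dimension 1 of a value of shape [N, C, H, W].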
class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
 public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TF::BiasAddOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    tensorflow::TensorFormat data_format;
    if (!FormatFromString(op.data_format().str(), &data_format))
      return op.emitOpError("invalid data format");

    auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
    if (!value_type) return failure();
    auto feature_dim = GetFeatureDimension(data_format, value_type);
    auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(),
                                                  feature_dim, rewriter);
    Value add = rewriter.create<AddOp>(loc, op.value(), bias_broadcast);
    if (add.getType() != op.getType()) {
      add = rewriter.create<tensor::CastOp>(loc, op.getType(), add);
    }
    rewriter.replaceOp(op, {add});
    return success();
  }
};

// Converts tf.Conv2D to mhlo.dynamic_conv.
// TODO(disc): Recover the static special case's performance by adding folding
// and canonicalization and removing ConvertConvOp.
template <typename OpT, int num_spatial_dims, bool depthwise_conv = false>
class ConvertConvDynamic : public OpRewritePattern<OpT> {
 public:
  using OpRewritePattern<OpT>::OpRewritePattern;

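  // Computes dynamic padding-low/padding-high values for one spatial
  // dimension. For SAME padding, e.g. with input_size = 5, filter_size = 3,
  // stride = 2 and dilation_rate = 1: effective_filter_size = 3,
  // output_size = (5 + 2 - 1) / 2 = 3, padding_needed = max(0, (3 - 1) * 2 +
  // 3 - 5) = 2, so padding_low = 1 and padding_high = 1. Returns false if the
  // stride or dilation rate is invalid.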
GetPaddingValues(OpT & op,PatternRewriter & rewriter,Value input_size,Value filter_size,int64_t dilation_rate,int64_t stride,tensorflow::Padding padding_type,Type shape_scalar_type,Value * padding_low,Value * padding_high) const1048 bool GetPaddingValues(OpT &op, PatternRewriter &rewriter, Value input_size,
1049 Value filter_size, int64_t dilation_rate,
1050 int64_t stride, tensorflow::Padding padding_type,
1051 Type shape_scalar_type, Value *padding_low,
1052 Value *padding_high) const {
1053 // Stride must be > 0
1054 if (stride <= 0) return false;
1055 // Dilation rate must be >= 1
1056 if (dilation_rate < 1) return false;
1057
1058 Location loc = op.getLoc();
1059 switch (padding_type) {
1060 case tensorflow::Padding::VALID: {
1061 auto zero =
1062 rewriter.create<arith::ConstantIntOp>(loc, 0, shape_scalar_type);
1063 *padding_low = *padding_high = zero;
1064 break;
1065 }
1066 case tensorflow::Padding::EXPLICIT:
1067 break;
1068 case tensorflow::Padding::SAME: {
1069 auto zero =
1070 rewriter.create<arith::ConstantIntOp>(loc, 0, shape_scalar_type);
1071 auto one =
1072 rewriter.create<arith::ConstantIntOp>(loc, 1, shape_scalar_type);
1073 auto two =
1074 rewriter.create<arith::ConstantIntOp>(loc, 2, shape_scalar_type);
1075 // See also the parallel implementation in
1076 // GetWindowedOutputSizeFromDimsV2. effective_filter_size = (filter_size
1077 // - 1) * dilation_rate + 1
1078 Value stride_value = rewriter.create<arith::ConstantIntOp>(
1079 loc, stride, shape_scalar_type);
1080 Value dilation_rate_value = rewriter.create<arith::ConstantIntOp>(
1081 loc, dilation_rate, shape_scalar_type);
1082 Value effective_filter_size_op = rewriter.create<arith::AddIOp>(
1083 loc, one,
1084 rewriter.create<arith::MulIOp>(
1085 loc, dilation_rate_value,
1086 rewriter.create<arith::SubIOp>(loc, filter_size, one)));
1087 // output_size = (input_size + stride - 1) / stride;
1088 Value output_size = rewriter.create<arith::DivUIOp>(
1089 loc,
1090 rewriter.create<arith::AddIOp>(
1091 loc, input_size,
1092 rewriter.create<arith::SubIOp>(loc, stride_value, one)),
1093 stride_value);
1094 // std::max(int64{0}, (output_size - 1) * stride +
1095 // effective_filter_size - input_size);
1096 Value padding_needed = rewriter.create<arith::SubIOp>(
1097 loc,
1098 rewriter.create<arith::AddIOp>(
1099 loc, effective_filter_size_op,
1100 rewriter.create<arith::MulIOp>(
1101 loc, stride_value,
1102 rewriter.create<arith::SubIOp>(loc, output_size, one))),
1103 input_size);
1104 Value cond = rewriter.create<mlir::arith::CmpIOp>(
1105 loc, arith::CmpIPredicate::sge, padding_needed, zero);
1106 padding_needed = rewriter.create<mlir::arith::SelectOp>(
1107 loc, padding_needed.getType(), cond, padding_needed, zero);
1108 *padding_low =
1109 rewriter.create<arith::DivUIOp>(loc, padding_needed, two);
1110 *padding_high =
1111 rewriter.create<arith::SubIOp>(loc, padding_needed, *padding_low);
1112 break;
1113 }
1114 }
1115 return true;
1116 }
1117
matchAndRewriteDynamicConv(OpT op,PatternRewriter & rewriter) const1118 LogicalResult matchAndRewriteDynamicConv(OpT op,
1119 PatternRewriter &rewriter) const {
1120 tensorflow::TensorFormat data_format;
1121 if (!FormatFromString(op.data_format().str(), &data_format))
1122 return op.emitOpError("invalid data format");
1123
1124 tensorflow::Padding padding;
1125 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1126 return failure();
1127
1128 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1129 auto filter_ty =
1130 op.filter().getType().template dyn_cast<RankedTensorType>();
1131 auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
1132 if (!input_ty || !filter_ty || !result_ty) return failure();
1133 // TODO(disc): Remove this constraint once fold and canonicalization
1134 // implemented.
1135 if (input_ty.hasStaticShape() && filter_ty.hasStaticShape())
1136 return failure();
1137
1138 ArrayRef<Attribute> dilations = op.dilations().getValue();
1139 ArrayRef<Attribute> strides = op.strides().getValue();
1140 ArrayRef<Attribute> explicit_paddings;
1141 if (padding == tensorflow::Padding::EXPLICIT) {
1142 // EXPLICIT padding mode and the associated attribute is attached to
1143 // Conv2D.
1144 explicit_paddings =
1145 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1146 }
1147
1148 SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1149 SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1150 SmallVector<int64_t, num_spatial_dims> window_strides;
1151 SmallVector<Value, num_spatial_dims * 2> paddings;
1152
1153 auto get_int = [](Attribute attr) {
1154 return attr.template cast<IntegerAttr>().getInt();
1155 };
1156
1157 constexpr int num_dims = num_spatial_dims + 2;
1158
1159 Location loc = op.getLoc();
1160 auto shape_scalar_type = rewriter.getIntegerType(32);
1161
1162 auto get_const = [&](int64_t val) {
1163 return rewriter.create<mlir::arith::ConstantIntOp>(loc, val,
1164 shape_scalar_type);
1165 };
1166 auto get_dim_value = [&](Value val, int64_t dim) {
1167 Value dim_value = rewriter.create<tensor::DimOp>(loc, val, dim);
1168 return rewriter.create<arith::IndexCastOp>(loc, shape_scalar_type,
1169 dim_value);
1170 };
1171
1172 for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1173 const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1174 spatial_dim_indices.push_back(dim);
1175
1176 const int64_t dilation = get_int(dilations[dim]);
1177 rhs_dilations.push_back(dilation);
1178 const int64_t stride = get_int(strides[dim]);
1179 window_strides.push_back(stride);
1180
1181 Value pad_low, pad_high;
1182 if (padding == tensorflow::Padding::EXPLICIT) {
1183 pad_low = get_const(get_int(explicit_paddings[2 * dim]));
1184 pad_high = get_const(get_int(explicit_paddings[2 * dim + 1]));
1185 } else {
1186 auto input_size = get_dim_value(op.input(), dim);
1187 auto filter_size = get_dim_value(op.filter(), i);
1188 if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation,
1189 stride, padding, shape_scalar_type, &pad_low,
1190 &pad_high)) {
1191 return failure();
1192 }
1193 }
1194 paddings.push_back(pad_low);
1195 paddings.push_back(pad_high);
1196 }
1197 auto rhs_dilations_attr = rewriter.getNamedAttr(
1198 "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1199
1200 auto window_strides_attr = rewriter.getNamedAttr(
1201 "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1202
1203 auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1204 spatial_dim_indices, data_format, &rewriter);
1205
1206 const int64_t input_channels =
1207 GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1208 // The filter's data_format is always HWIO, so the input channels dimension
1209 // comes after all spatial dimensions.
1210 const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1211 // TensorFlow convolution op verifies that the number of input channels is
1212 // divisible by the number of filter channels.
1213 // For depthwise convolution the feature_group_count argument would be set
1214 // to the input feature dimension.
1215 const int64_t feature_group_count =
1216 depthwise_conv ? input_channels : input_channels / filter_channels;
1217 auto feature_group_count_attr = rewriter.getNamedAttr(
1218 "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1219
1220 auto batch_group_count_attr = rewriter.getNamedAttr(
1221 "batch_group_count", rewriter.getI64IntegerAttr(1));
1222
1223 Value paddings_op = rewriter.create<tensor::FromElementsOp>(
1224 op.getLoc(),
1225 RankedTensorType::get(2 * num_spatial_dims, rewriter.getI32Type()),
1226 paddings);
1227
1228 SmallVector<Value, 3> operands(op.getOperands());
1229 operands.push_back(paddings_op);
1230 // Reshape the filter to {spatial_dims..., 1, in_channels *
1231 // channel_multiplier}.
1232 if (depthwise_conv) {
1233 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1234 llvm::SmallVector<int64_t, num_dims> new_shape(
1235 filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1236 new_shape.push_back(1);
1237 new_shape.push_back(filter_shape[num_spatial_dims] *
1238 filter_shape[num_spatial_dims + 1]);
1239 operands[1] = rewriter.create<mhlo::ReshapeOp>(
1240 op.getLoc(),
1241 RankedTensorType::get(new_shape, filter_ty.getElementType()),
1242 operands[1]);
1243 }
1244 NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
1245 dimension_numbers_attr, feature_group_count_attr,
1246 batch_group_count_attr};
1247 rewriter.replaceOpWithNewOp<mhlo::DynamicConvOp>(op, op.getType(), operands,
1248 llvm::makeArrayRef(attrs));
1249 return success();
1250 }
1251
1252 LogicalResult matchAndRewrite(OpT op,
1253 PatternRewriter &rewriter) const override {
1254 return matchAndRewriteDynamicConv(op, rewriter);
1255 }
1256 };
1257
1258 using ConvertConv2DDynamic =
1259 ConvertConvDynamic<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1260
1261 // Converts the TensorFlow conv op in template to the generic HLO conv op by
1262 // converting TensorFlow op attributes to HLO op attributes.
1263 //
1264 // Sample result for Conv2D:
1265 //
1266 // %conv = "mhlo.convolution"(%input, %filter) {
1267 // strides = [1, 2],
1268 // paddings = [[1, 0], [1, 1]],
1269 // ...
1270 // }
1271 //
1272 // This pattern is not defined using declarative rewrite rules, as computing
1273 // the paddings attribute requires multiple source op attributes and
1274 // result op attributes. Defining it as a declarative rewrite rule would
1275 // introduce duplication in the C++ helper methods.
1276 template <typename OpTy, int num_spatial_dims, bool depthwise_conv = false>
1277 class ConvertConvOp : public OpRewritePattern<OpTy> {
1278 public:
1279 using OpRewritePattern<OpTy>::OpRewritePattern;
1280
1281 LogicalResult matchAndRewrite(OpTy op,
1282 PatternRewriter &rewriter) const override {
1283 tensorflow::TensorFormat data_format;
1284 if (!FormatFromString(op.data_format().str(), &data_format))
1285 return op.emitOpError("invalid data format");
1286
1287 tensorflow::Padding padding;
1288 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
1289 return failure();
1290
1291 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
1292 auto filter_ty =
1293 op.filter().getType().template dyn_cast<RankedTensorType>();
1294
1295 // With the exception of the input's batch dimension, the input and filter
1296 // need to have static shapes to calculate the HLO paddings and feature group
1297 // count attributes. The filter is validated here; the input is mostly
1297 // validated at use.
1298 if (!input_ty || !filter_ty || !filter_ty.hasStaticShape())
1299 return failure();
1300
1301 ArrayRef<Attribute> dilations = op.dilations().getValue();
1302 ArrayRef<Attribute> strides = op.strides().getValue();
1303 ArrayRef<Attribute> explicit_paddings;
1304 if (padding == tensorflow::Padding::EXPLICIT) {
1305 // The EXPLICIT padding mode and the associated attribute are limited to
1306 // Conv2D, so fetch the attribute by identifier instead of using the
1307 // op.explicit_paddings() attribute getter.
1308 explicit_paddings =
1309 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
1310 }
1311
1312 SmallVector<int64_t, num_spatial_dims> spatial_dim_indices;
1313 SmallVector<int64_t, num_spatial_dims> rhs_dilations;
1314 SmallVector<int64_t, num_spatial_dims> window_strides;
1315 SmallVector<int64_t, num_spatial_dims * 2> paddings;
1316
1317 auto get_int = [](Attribute attr) {
1318 return attr.template cast<IntegerAttr>().getInt();
1319 };
1320
1321 constexpr int num_dims = num_spatial_dims + 2;
1322 for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
1323 const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
1324 spatial_dim_indices.push_back(dim);
1325
1326 const int64_t dilation = get_int(dilations[dim]);
1327 rhs_dilations.push_back(dilation);
1328 const int64_t stride = get_int(strides[dim]);
1329 window_strides.push_back(stride);
1330
1331 int64_t pad_low, pad_high;
1332 if (padding == tensorflow::Padding::EXPLICIT) {
1333 pad_low = get_int(explicit_paddings[2 * dim]);
1334 pad_high = get_int(explicit_paddings[2 * dim + 1]);
1335 } else {
1336 int64_t output_size;
1337 int64_t pad_low_int64;
1338 int64_t pad_high_int64;
1339 int64_t input_size = input_ty.getDimSize(dim);
1340 if (input_size == ShapedType::kDynamicSize) return failure();
1341 tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerboseV2(
1342 input_size, filter_ty.getDimSize(i), dilation, stride, padding,
1343 &output_size, &pad_low_int64, &pad_high_int64);
1344 if (!status.ok()) return failure();
1345 pad_low = pad_low_int64;
1346 pad_high = pad_high_int64;
1347 }
1348 paddings.push_back(pad_low);
1349 paddings.push_back(pad_high);
1350 }
1351
1352 auto rhs_dilations_attr = rewriter.getNamedAttr(
1353 "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter));
1354
1355 auto window_strides_attr = rewriter.getNamedAttr(
1356 "window_strides", GetI64ElementsAttr(window_strides, &rewriter));
1357
1358 auto dimension_numbers_attr = GetConvDimensionNumbersAttr(
1359 spatial_dim_indices, data_format, &rewriter);
1360
1361 const int64_t input_channels =
1362 GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format));
1363 if (input_channels == ShapedType::kDynamicSize) return failure();
1364 // The filter's data_format is always HWIO, so the input channels dimension
1365 // comes after all spatial dimensions.
1366 const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims);
1367 // TensorFlow convolution op verifies that the number of input channels is
1368 // divisible by the number of filter channels.
1369 // For depthwise convolution the feature_group_count argument would be set
1370 // to the input feature dimension.
1371 const int64_t feature_group_count =
1372 depthwise_conv ? input_channels : input_channels / filter_channels;
1373 auto feature_group_count_attr = rewriter.getNamedAttr(
1374 "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count));
1375
1376 auto batch_group_count_attr = rewriter.getNamedAttr(
1377 "batch_group_count", rewriter.getI64IntegerAttr(1));
1378
1379 RankedTensorType paddings_ty = RankedTensorType::get(
1380 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
1381 auto paddings_attr = rewriter.getNamedAttr(
1382 "padding", DenseElementsAttr::get<int64_t>(paddings_ty, paddings));
1383
1384 SmallVector<Value, 2> operands(op.getOperands());
1385 // Reshape the filter to {spatial_dims..., 1, in_channels *
1386 // channel_multiplier}.
1387 if (depthwise_conv) {
1388 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
1389 llvm::SmallVector<int64_t, num_dims> new_shape(
1390 filter_shape.begin(), filter_shape.begin() + num_spatial_dims);
1391 new_shape.push_back(1);
1392 new_shape.push_back(filter_shape[num_spatial_dims] *
1393 filter_shape[num_spatial_dims + 1]);
1394 operands[1] = rewriter.create<mhlo::ReshapeOp>(
1395 op.getLoc(),
1396 RankedTensorType::get(new_shape, filter_ty.getElementType()),
1397 operands[1]);
1398 }
1399 NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr,
1400 dimension_numbers_attr, feature_group_count_attr,
1401 batch_group_count_attr, paddings_attr};
1402 rewriter.replaceOpWithNewOp<ConvolutionOp>(op, op.getType(), operands,
1403 llvm::makeArrayRef(attrs));
1404 return success();
1405 }
1406 };
1407
1408 using ConvertConv2DOp = ConvertConvOp<TF::Conv2DOp, /*num_spatial_dims=*/2>;
1409 using ConvertConv3DOp = ConvertConvOp<TF::Conv3DOp, /*num_spatial_dims=*/3>;
1410 using ConvertDepthConv2DOp =
1411 ConvertConvOp<TF::DepthwiseConv2dNativeOp, /*num_spatial_dims=*/2,
1412 /*depthwise_conv=*/true>;
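// For illustration (a hedged sketch; exact types depend on the input program):
// a tf.DepthwiseConv2dNative whose filter has type tensor<3x3x8x2xf32> (HWIM)
// is lowered with the filter reshaped to tensor<3x3x1x16xf32> and
// feature_group_count = 8, since for depthwise convolution
// feature_group_count is set to the number of input channels.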
1413
1414 // Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const.
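// A rough sketch of the emitted ops (names and element types are illustrative
// only): the paddings operand of shape [rank, 2] is transposed and flattened,
// the low/high halves are sliced out, and a zero interior-padding tensor is
// materialized:
//
//   %t = "mhlo.transpose"(%paddings) {permutation = dense<[1, 0]> : ...}
//   %flat = "mhlo.reshape"(%t)            // shape [2 * rank]
//   %low  = "mhlo.slice"(%flat)           // elements [0, rank)
//   %high = "mhlo.slice"(%flat)           // elements [rank, 2 * rank)
//   %res  = "mhlo.dynamic_pad"(%input, %constant_values, %low, %high, %interior)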
1415 class ConvertPadOpDynamic : public OpRewritePattern<TF::PadV2Op> {
1416 public:
1417 using OpRewritePattern::OpRewritePattern;
1418 // TODO(disc): To recover static special case's performance with folding and
1419 // canonicalization.
1420 LogicalResult matchAndRewrite(TF::PadV2Op op,
1421 PatternRewriter &rewriter) const override {
1422 Location loc = op.getLoc();
1423 auto input = op.input();
1424 auto paddings = op.paddings();
1425 auto constant_values = op.constant_values();
1426 auto input_type = input.getType().dyn_cast<RankedTensorType>();
1427 auto paddings_type = paddings.getType().dyn_cast<RankedTensorType>();
1428 if (!input_type || !paddings_type || !paddings_type.hasStaticShape())
1429 return failure();
1430
1431 // TODO(disc): Remove this constraint once fold and canonicalization are
1432 // implemented.
1433 if (input_type.hasStaticShape()) return failure();
1434
1435 int input_rank = input_type.getRank();
1436 // interior padding
1437 std::vector<int64_t> interior_values(input_rank, 0);
1438 auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter);
1439
1440 Value interior_padding_tensor =
1441 rewriter.create<mhlo::ConstantOp>(loc, interior_attr);
1442 Type paddings_elem_ty = paddings_type.getElementType();
1443 if (!paddings_elem_ty.isInteger(64)) {
1444 interior_padding_tensor = rewriter.create<mhlo::ConvertOp>(
1445 loc, interior_padding_tensor, paddings_elem_ty);
1446 }
1447 llvm::SmallVector<int64_t, 2> transposed_shape = {2, input_rank};
1448 auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter);
1449 Value transposed_paddings =
1450 rewriter.create<mhlo::TransposeOp>(loc, paddings, transpose_attr);
1451 Value reshaped_paddings = rewriter.create<mhlo::ReshapeOp>(
1452 loc, RankedTensorType::get({input_rank * 2}, paddings_elem_ty),
1453 transposed_paddings);
1454
1455 auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter);
1456 auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1457 auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1458 Value left_padding_tensor = rewriter.create<mhlo::SliceOp>(
1459 loc, reshaped_paddings, left_padding_start_attr,
1460 left_padding_limit_attr, left_padding_stride_attr);
1461
1462 auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter);
1463 auto right_padding_limit_attr =
1464 GetI64ElementsAttr({2 * input_rank}, &rewriter);
1465 auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter);
1466 Value right_padding_tensor = rewriter.create<mhlo::SliceOp>(
1467 loc, reshaped_paddings, right_padding_start_attr,
1468 right_padding_limit_attr, right_padding_stride_attr);
1469
1470 rewriter.replaceOpWithNewOp<mhlo::DynamicPadOp>(
1471 op, op.getType(), input, constant_values, left_padding_tensor,
1472 right_padding_tensor, interior_padding_tensor);
1473
1474 return success();
1475 }
1476 };
1477
1478 class ConvertGatherNdOpDynamic : public OpRewritePattern<TF::GatherNdOp> {
1479 using OpRewritePattern<TF::GatherNdOp>::OpRewritePattern;
1480 // Converts tf.GatherNdOp to mhlo.DynamicGatherOp.
1481 // Here we leave 'slice_sizes' as an Attr, without defining a new
1482 // DynamicGatherOp, since GatherDimensionNumbers already provides enough
1483 // information for shape inference and code generation of mhlo::GatherOp. '?'
1484 // will be filled into slice_sizes for dimensions that are dynamically sized.
1485 // TODO(disc): To recover static special case's performance with folding and
1486 // canonicalization.
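// For example (a hedged sketch; exact attribute values depend on the ranks):
// gathering with params : tensor<?x?xf32> and indices : tensor<?x1xi32>
// produces slice_sizes = [1, dim(params, 1)] computed at runtime,
// collapsed_slice_dims = [0], start_index_map = [0], offset_dims = [1], and
// index_vector_dim = 1, feeding an mhlo.dynamic_gather.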
1487 LogicalResult matchAndRewrite(TF::GatherNdOp op,
1488 PatternRewriter &rewriter) const override {
1489 Location loc = op.getLoc();
1490 auto params = op.params();
1491 auto params_ty = params.getType().dyn_cast<RankedTensorType>();
1492 auto indices = op.indices();
1493 auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
1494 if (!params_ty || !indices_ty) return failure();
1495 auto params_rank = params_ty.getRank();
1496 auto indices_rank = indices_ty.getRank();
1497 int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1);
1498 // The last dim of the indices of GatherNdOp must be statically shaped.
1499 if (num_index_dims == ShapedType::kDynamicSize) return failure();
1500
1501 SmallVector<int64_t, 4> slice_sizes;
1502 slice_sizes.reserve(params_rank);
1503 for (int64_t i = 0; i < params_rank; ++i) {
1504 if (i < num_index_dims) {
1505 slice_sizes.push_back(1);
1506 } else {
1507 // potentially dynamic
1508 int64_t dim_size = params_ty.getDimSize(i);
1509 slice_sizes.push_back(dim_size);
1510 }
1511 }
1512 SmallVector<Value, 4> slice_sizes_vals;
1513 Value slice_sizes_value = nullptr;
1514 for (int64_t i = 0; i < params_rank; ++i) {
1515 if (i < num_index_dims) {
1516 slice_sizes_vals.push_back(rewriter.create<arith::ConstantOp>(
1517 loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1)));
1518 } else {
1519 int64_t dim_size = params_ty.getDimSize(i);
1520 if (dim_size != ShapedType::kDynamicSize) {
1521 slice_sizes_vals.push_back(rewriter.create<arith::ConstantOp>(
1522 loc,
1523 rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size)));
1524 } else {
1525 slice_sizes_vals.push_back(rewriter.create<arith::IndexCastOp>(
1526 loc, indices_ty.getElementType(),
1527 rewriter.create<tensor::DimOp>(loc, params, i)));
1528 }
1529 }
1530 }
1531 slice_sizes_value =
1532 rewriter.create<tensor::FromElementsOp>(loc, slice_sizes_vals);
1533
1534 // collapsed_slice_dims
1535 SmallVector<int64_t, 4> collapsed_slice_dims;
1536 collapsed_slice_dims.reserve(num_index_dims);
1537 for (int64_t i = 0; i < num_index_dims; ++i) {
1538 collapsed_slice_dims.push_back(i);
1539 }
1540 // offset_dims
1541 SmallVector<int64_t, 4> offset_dims;
1542 offset_dims.reserve(params_rank - num_index_dims);
1543 for (int64_t i = num_index_dims; i < params_rank; i++) {
1544 offset_dims.push_back(i + indices_rank - 1 - num_index_dims);
1545 }
1546 // start_index_map
1547 SmallVector<int64_t, 4> start_index_map;
1548 start_index_map.reserve(num_index_dims);
1549 for (int64_t i = 0; i < num_index_dims; i++) {
1550 start_index_map.push_back(i);
1551 }
1552 // index_vector_dim
1553 int64_t index_vector_dim = indices_rank - 1;
1554
1555 auto dims_attr = GatherDimensionNumbersAttr::get(
1556 rewriter.getContext(), offset_dims, collapsed_slice_dims,
1557 start_index_map, index_vector_dim);
1558 // TODO(disc): Remove this if-statement once fold and canonicalization are
1559 // implemented.
1560 if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) {
1561 rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
1562 op, op.getType(), op.params(), op.indices(), dims_attr,
1563 GetI64ElementsAttr(slice_sizes, &rewriter));
1564 } else {
1565 rewriter.replaceOpWithNewOp<mhlo::DynamicGatherOp>(
1566 op, op.getType(), op.params(), op.indices(), slice_sizes_value,
1567 dims_attr);
1568 }
1569 return success();
1570 }
1571 };
1572
1573 // Converts BF16 FloorDiv op to have casting operators on either end as BF16
1574 // division can result in strange behavior.
1575 //
1576 // floordiv = cast(floordiv(cast(left), cast(right)))
1577 //
1578 // %left_cast = cast(%left)
1579 // %right_cast = cast(%right)
1580 // %div = div(%left_cast, %right_cast)
1581 // %floored = floor(%div)
1582 // %floored_cast = cast(%floored)
1583 //
1584 // The intermediate types must be specified manually.
1585 class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
1586 public:
1587 using OpRewritePattern::OpRewritePattern;
1588
1589 LogicalResult matchAndRewrite(TF::FloorDivOp op,
1590 PatternRewriter &rewriter) const override {
1591 auto l = op.x();
1592 auto r = op.y();
1593 auto element_type = getElementTypeOrSelf(l.getType());
1594 if (!element_type.isBF16()) return failure();
1595
1596 auto out_type = op.z().getType().cast<TensorType>();
1597
1598 l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
1599 r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
1600
1601 auto intermediate = rewriter.create<TF::FloorDivOp>(
1602 op.getLoc(),
1603 ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l,
1604 r);
1605
1606 auto floor_op =
1607 rewriter.create<ConvertOp>(op.getLoc(), out_type, intermediate);
1608 rewriter.replaceOp(op, floor_op.getResult());
1609 return success();
1610 }
1611 };
1612
1613 class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
1614 public:
1615 using OpRewritePattern::OpRewritePattern;
1616
1617 LogicalResult matchAndRewrite(TF::BroadcastToOp op,
1618 PatternRewriter &rewriter) const override {
1619 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1620 auto output_type = op.output().getType();
1621 if (!input_type) {
1622 return rewriter.notifyMatchFailure(op, "requires ranked input shape");
1623 }
1624 llvm::SmallVector<int64_t, 4> broadcast_dimensions;
1625 if (input_type.getRank() > 0) {
1626 auto ranked_output_type = output_type.dyn_cast<RankedTensorType>();
1627 if (!ranked_output_type) {
1628 return rewriter.notifyMatchFailure(op, "requires ranked output shape");
1629 }
1630 auto rank_diff = ranked_output_type.getRank() - input_type.getRank();
1631 // The tf.BroadcastTo op performs "right-aligned" numpy-style
1632 // broadcasting.
1633 broadcast_dimensions = llvm::to_vector<4>(
1634 llvm::seq<int64_t>(rank_diff, ranked_output_type.getRank()));
1635 }
1636 rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
1637 op, output_type, op.input(), op.shape(),
1638 rewriter.getI64TensorAttr(broadcast_dimensions));
1639 return success();
1640 }
1641 };
1642
1643 /// Converts a TF::RollOp to HLO. Only supports the 0D axis and shift case,
1644 /// and the axis has to be a constant.
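/// A rough sketch of the lowering (illustrative only): for an input
/// tensor<8xf32>, shift %s and axis = 0, the input is concatenated with itself
/// along axis 0 and a dynamic-slice of size 8 starts at 8 - (%s mod 8), which
/// is equivalent to rolling by %s.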
1645 class ConvertRollOp : public OpRewritePattern<TF::RollOp> {
1646 public:
1647 using OpRewritePattern::OpRewritePattern;
1648 LogicalResult matchAndRewrite(TF::RollOp op,
1649 PatternRewriter &rewriter) const override {
1650 auto shift_ty = op.shift().getType().dyn_cast<RankedTensorType>();
1651 if (!shift_ty || shift_ty.getRank() != 0) {
1652 return rewriter.notifyMatchFailure(
1653 op, "require the type of shift to be 0D tensor");
1654 }
1655
1656 APInt val;
1657 if (!matchPattern(op.axis(), m_ConstantInt(&val))) {
1658 return rewriter.notifyMatchFailure(op, "require axis to be constant");
1659 }
1660 int axis = val.getSExtValue();
1661
1662 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
1663 if (!input_ty || !input_ty.hasStaticShape()) {
1664 return rewriter.notifyMatchFailure(
1665 op, "require the type of input to have static shapes");
1666 }
1667 ArrayRef<int64_t> input_shape = input_ty.getShape();
1668 int input_rank = input_ty.getRank();
1669 if (axis < 0) axis += input_rank;
1670
1671 // Adjust large offsets into [0, axis_size). This also makes negative
1672 // offsets positive.
1673 // offset = ((offset % axis_size) + axis_size) % axis_size
1674 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1675 Value offset = op.shift();
1676 auto axis_size = b.create<mhlo::ConstantOp>(b.getIntegerAttr(
1677 getElementTypeOrSelf(offset.getType()), input_shape[axis]));
1678 offset = b.create<RemOp>(
1679 b.create<AddOp>(b.create<RemOp>(offset, axis_size), axis_size),
1680 axis_size);
1681
1682 // Stack two copies of the dimension, then slice from the calculated
1683 // offset. This also works if shift is not constant.
1684 // DynamicSliceOp requires the sizes to be integers, and we can get that
1685 // information from the input shape.
1686 auto concat = b.create<ConcatenateOp>(ValueRange{op.input(), op.input()},
1687 b.getI64IntegerAttr(axis));
1688 Value zero = b.create<mhlo::ConstantOp>(
1689 b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0));
1690 SmallVector<Value> slice_begin_indices(input_rank, zero);
1691 slice_begin_indices[axis] = b.create<SubtractOp>(axis_size, offset);
1692 rewriter.replaceOpWithNewOp<DynamicSliceOp>(
1693 op, input_ty, concat, slice_begin_indices,
1694 rewriter.getI64TensorAttr(input_shape));
1695 return success();
1696 }
1697 };
1698
1699 /// Converts a TF::LeakyReluOp to HLO.
1700 /// LeakyRelu(x) = alpha * x if x < 0 else x.
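/// A rough sketch of the emitted ops (illustrative; chlo.constant_like carries
/// alpha and zero with the shape of %features):
///   %alpha  = "chlo.constant_like"(%features) {value = alpha}
///   %zero   = "chlo.constant_like"(%features) {value = 0.0}
///   %scaled = mhlo.multiply %features, %alpha
///   %pred   = "mhlo.compare"(%features, %zero) {comparison_direction = GT}
///   %result = "mhlo.select"(%pred, %features, %scaled)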
1701 class ConvertLeakyReluOp : public OpRewritePattern<TF::LeakyReluOp> {
1702 public:
1703 using OpRewritePattern::OpRewritePattern;
1704 LogicalResult matchAndRewrite(TF::LeakyReluOp op,
1705 PatternRewriter &rewriter) const override {
1706 Location loc = op.getLoc();
1707 Value features = op.features();
1708
1709 // Use ConstantLike for `alpha` to match the shape of feature.
1710 auto alphaVal = chlo::getConstantLike(
1711 rewriter, loc, op.alpha().convertToFloat(), features);
1712 Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
1713
1714 Value leakyActivationVal =
1715 rewriter.create<mhlo::MulOp>(loc, features, alphaVal);
1716
1717 Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1718 loc, features, zeroVal, ComparisonDirection::GT);
1719
1720 rewriter.replaceOpWithNewOp<SelectOp>(op, compareGtZero, features,
1721 leakyActivationVal);
1722 return success();
1723 }
1724 };
1725
1726 /// Converts a TF::LeakyReluGradOp to HLO.
1727 /// LeakyReluGrad(gradient, inputs) = gradient if input > 0
1728 /// else alpha * gradient.
1729 class ConvertLeakyReluGradOp : public OpRewritePattern<TF::LeakyReluGradOp> {
1730 public:
1731 using OpRewritePattern::OpRewritePattern;
1732 LogicalResult matchAndRewrite(TF::LeakyReluGradOp op,
1733 PatternRewriter &rewriter) const override {
1734 Location loc = op.getLoc();
1735 Value gradients = op.gradients();
1736 Value features = op.features();
1737 auto featureType = features.getType();
1738
1739 // Use ConstantLike for `alpha` to match the shape of feature.
1740 auto alphaVal = chlo::getConstantLike(
1741 rewriter, loc, op.alpha().convertToFloat(), features);
1742 Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features);
1743
1744 Value leakyGradientVal =
1745 rewriter.create<mhlo::MulOp>(loc, gradients, alphaVal);
1746
1747 Value compareGtZero = rewriter.create<mhlo::CompareOp>(
1748 loc, features, zeroVal, ComparisonDirection::GT);
1749
1750 rewriter.replaceOpWithNewOp<SelectOp>(op, featureType, compareGtZero,
1751 gradients, leakyGradientVal);
1752 return success();
1753 }
1754 };
1755
1756 // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
1757 // For a Rank-2 input, it creates the following ops:
1758 // %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
1759 // %2 = "mhlo.iota"() {iota_dimension = 1 : i64}
1760 // %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"}
1761 // %4 = mhlo.constant dense<0.000000e+00> : tensor<f32>
1762 // %5 = "mhlo.broadcast"(%4)
1763 // %6 = "mhlo.select"(%3, %input, %5)
1764 // %7 = "mhlo.reduce"(%6, %4) ({
1765 // ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
1766 // %9 = mhlo.add %arg1, %arg2 : tensor<f32>
1767 // "mhlo.return"(%9) : (tensor<f32>) -> ()
1768 // }) {dimensions = dense<0> : tensor<1xi64>}
1769 //
1770 // If the input's rank N is greater than 2, we reshape it to rank 2 first,
1771 // create the above ops, and then reshape the result back to rank N/2.
1772 class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
1773 public:
1774 using OpRewritePattern::OpRewritePattern;
1775
1776 LogicalResult matchAndRewrite(TF::DiagPartOp op,
1777 PatternRewriter &rewriter) const override {
1778 auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
1779 if (!input_type || !input_type.hasStaticShape()) return failure();
1780 int64_t num_dims = input_type.getRank();
1781 if (num_dims < 2 || num_dims % 2 != 0) return failure();
1782 const int64_t out_dims = num_dims / 2;
1783
1784 int64_t new_size = 1;
1785 llvm::SmallVector<int64_t, 4> new_dims;
1786 for (int i = 0; i < out_dims; i++) {
1787 if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims))
1788 return op.emitOpError("invalid dimensions size");
1789 new_size *= input_type.getDimSize(i);
1790 new_dims.push_back(input_type.getDimSize(i));
1791 }
1792 Value reshaped_input = rewriter.create<mhlo::ReshapeOp>(
1793 op.getLoc(),
1794 RankedTensorType::get({new_size, new_size},
1795 input_type.getElementType()),
1796 op.input());
1797 auto iota_type = RankedTensorType::get({new_size, new_size},
1798 rewriter.getIntegerType(32));
1799 auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1800 rewriter.getI64IntegerAttr(0));
1801 auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
1802 rewriter.getI64IntegerAttr(1));
1803 Value compare = rewriter.create<CompareOp>(op.getLoc(), iota0, iota1,
1804 ComparisonDirection::EQ);
1805 Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(),
1806 0, &rewriter);
1807 Value zero_matrix = rewriter.create<BroadcastOp>(
1808 op.getLoc(), reshaped_input.getType(), zero,
1809 GetI64ElementsAttr({new_size, new_size}, &rewriter));
1810 Value masked =
1811 rewriter.create<SelectOp>(op.getLoc(), reshaped_input.getType(),
1812 compare, reshaped_input, zero_matrix);
1813 auto reduce = rewriter.create<ReduceOp>(op.getLoc(), masked, zero,
1814 GetI64ElementsAttr({0}, &rewriter));
1815 assert(!input_type.getElementType().isInteger(1) &&
1816 "data type should not be i1");
1817 BuildReduceBody<AddOp>(input_type.getElementType(), &reduce.body(),
1818 &rewriter);
1819 rewriter.replaceOpWithNewOp<ReshapeOp>(
1820 op, RankedTensorType::get(new_dims, input_type.getElementType()),
1821 reduce.getResult(0));
1822 return success();
1823 }
1824 };
1825
1826 // Converts TensorFlow MatrixDiagPartOp to HLO ops.
1827 class ConvertMatrixDiagPartV3Op
1828 : public OpRewritePattern<TF::MatrixDiagPartV3Op> {
1829 using Shape = llvm::SmallVector<int64_t, 4>;
1830
1831 // Parse the "k" parameter. MatrixDiagPartV3 allows specifying the diagonal(s)
1832 // with k. This can be either a single value (for a single diagonal) or a
1833 // tuple of two values (starting and ending diagonal, for a band).
1834 LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const {
1835 DenseIntElementsAttr kattr;
1836 if (!matchPattern(op.k(), m_Constant(&kattr))) {
1837 return failure();
1838 }
1839 DenseIntElementsAttr::iterator it = kattr.begin();
1840 (*k)[0] = (*it).getSExtValue();
1841 it++;
1842 if (it == kattr.end()) {
1843 // Handle input like e.g. "k = 5", in which case we extract a single
1844 // diagonal.
1845 (*k)[1] = (*k)[0];
1846 } else {
1847 // Handle input like e.g. "k = [-1, 1]", in which case we extract a
1848 // band (multiple diagonals).
1849 (*k)[1] = (*it).getSExtValue();
1850 }
1851 return success();
1852 }
1853
1854 // Utility method for broadcasting integer constants to a given shape.
1855 BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant,
1856 int int_size, PatternRewriter &rewriter) const {
1857 return rewriter.create<BroadcastOp>(
1858 loc, RankedTensorType::get(shape, rewriter.getIntegerType(int_size)),
1859 GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant,
1860 &rewriter),
1861 GetI64ElementsAttr(shape, &rewriter));
1862 }
1863
1864 public:
1865 using OpRewritePattern::OpRewritePattern;
1866
1867 LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op,
1868 PatternRewriter &rewriter) const override {
1869 Location loc = op.getLoc();
1870 ShapedType input_type = op.input().getType().dyn_cast<ShapedType>();
1871
1872 // Align is a string specifying how superdiagonals and subdiagonals should
1873 // be aligned/padded for diagonals that are shorter than max_diag_len. The
1874 // format is "{super}_{sub}", with {super} the superdiagonal alignment and
1875 // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the
1876 // left, "RIGHT" means rows will be padded to the right. The default is
1877 // "RIGHT_LEFT".
1878 StringRef align = op->getAttrOfType<StringAttr>("align").getValue();
1879 enum Alignment { kLeft, kRight };
1880
1881 // default is RIGHT_LEFT
1882 Alignment superdiagonal_align = kRight;
1883 Alignment subdiagonal_align = kLeft;
1884
1885 if (align == "RIGHT_LEFT") {
1886 superdiagonal_align = kRight;
1887 subdiagonal_align = kLeft;
1888 } else if (align == "RIGHT_RIGHT") {
1889 superdiagonal_align = kRight;
1890 subdiagonal_align = kRight;
1891 } else if (align == "LEFT_RIGHT") {
1892 superdiagonal_align = kLeft;
1893 subdiagonal_align = kRight;
1894 } else if (align == "LEFT_LEFT") {
1895 superdiagonal_align = kLeft;
1896 subdiagonal_align = kLeft;
1897 } else {
1898 return failure(); // unsupported alignment
1899 }
1900
1901 // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and
1902 // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L].
1903 if (!input_type || !input_type.hasStaticShape()) return failure();
1904 int64_t num_dims = input_type.getRank();
1905 if (num_dims < 2) return failure();
1906 int64_t rows = input_type.getDimSize(num_dims - 2); // rows
1907 int64_t cols = input_type.getDimSize(num_dims - 1); // cols
1908
1909 // We extract the diagonals from k[0] up to and including k[1].
1910 // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract
1911 // the main diagonal). It's negative for subdiagonals (under and to the left
1912 // of the main diagonal) and positive for superdiagonals (above and to the
1913 // right of the main diagonal).
1914 int64_t k[2];
1915 if (failed(ExtractK(op, &k))) return failure();
1916 int num_diags = k[1] - k[0] + 1;
1917
1918 // Shifting diagonals away from the main diagonal might shorten them. This
1919 // is the longest diagonal we will see. We make this the last dimension of
1920 // the output shape.
1921 int64_t max_diag_len =
1922 std::min(rows + std::min(k[1], static_cast<int64_t>(0)),
1923 cols + std::min(-k[0], static_cast<int64_t>(0)));
1924
1925 // The first dimension is the index vector dimension we'll use for gather.
1926 // It's 1 here, but will be 2 once we glue x and y together.
1927 Shape indices_shape({1, num_diags, max_diag_len});
1928
1929 RankedTensorType iota_type =
1930 RankedTensorType::get(indices_shape, rewriter.getIntegerType(32));
1931 Value iotaM =
1932 rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(1));
1933 Value iotaN =
1934 rewriter.create<IotaOp>(loc, iota_type, rewriter.getI64IntegerAttr(2));
1935
1936 // Broadcast constants of the same shape as iotaM and iotaN.
1937 Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter);
1938 Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter);
1939 Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter);
1940 Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter);
1941 Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter);
1942 Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter);
1943 Value b_max_diag_len =
1944 BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter);
1945
1946 // d = k[1] - m
1947 // (A.k.a. the number of the diagonal, depending on m. Note that we
1948 // subtract m here. This means we start with the superdiagonals and
1949 // move downwards towards the subdiagonals. So the start indices will
1950 // be decreasing.)
1951 Value d = rewriter.create<SubtractOp>(loc, b_k1, iotaM);
1952 Value neg_d = rewriter.create<NegOp>(loc, d);
1953
1954 // diag_len_d = min(rows + min(d, 0), cols - max(d, 0))
1955 // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.)
1956 Value diag_len_d = rewriter.create<MinOp>(
1957 loc,
1958 rewriter.create<AddOp>(loc, b_rows,
1959 rewriter.create<MinOp>(loc, d, b_zero)),
1960 rewriter.create<SubtractOp>(loc, b_cols,
1961 rewriter.create<MaxOp>(loc, d, b_zero)));
1962
1963 // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise.
1964 Value cmp;
1965 if (subdiagonal_align == kRight && superdiagonal_align == kRight) {
1966 cmp = b_true;
1967 } else if (superdiagonal_align == kRight) {
1968 // offset = d>=0 ? max_diag_len - diag_len_d : 0
1969 cmp = rewriter.create<TF::GreaterEqualOp>(loc, d, b_zero);
1970 } else if (subdiagonal_align == kRight) {
1971 // offset = d<=0 ? max_diag_len - diag_len_d : 0
1972 cmp = rewriter.create<TF::LessEqualOp>(loc, d, b_zero);
1973 } else {
1974 // offset = 0
1975 cmp = b_false;
1976 }
1977
1978 // This offset shifts the diagonals to the "left" or "right", depending
1979 // on alignment.
1980 Value offset = rewriter.create<SelectOp>(
1981 loc, b_zero.getType(), cmp,
1982 rewriter.create<SubtractOp>(loc, b_max_diag_len, diag_len_d), b_zero);
1983
1984 // x = max(d, 0) - offset
1985 // y = max(-d, 0) - offset
1986 Value x = rewriter.create<SubtractOp>(
1987 loc, rewriter.create<MaxOp>(loc, d, b_zero), offset);
1988 Value y = rewriter.create<SubtractOp>(
1989 loc, rewriter.create<MaxOp>(loc, neg_d, b_zero), offset);
1990
1991 Value n_plus_x = rewriter.create<AddOp>(loc, iotaN, x);
1992 Value n_plus_y = rewriter.create<AddOp>(loc, iotaN, y);
1993
1994 // GatherOp is happy about letting us index out of bounds values, but those
1995 // values will be undefined. So we mask them later. Set up the boolean
1996 // expression that tells us which entries, in the output shape, are out of
1997 // bounds and thus become the padding_value.
1998 Value x_in_bounds = rewriter.create<AndOp>(
1999 loc,
2000 rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_x,
2001 b_zero),
2002 rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_x, b_cols));
2003 Value y_in_bounds = rewriter.create<AndOp>(
2004 loc,
2005 rewriter.create<TF::GreaterEqualOp>(loc, b_false.getType(), n_plus_y,
2006 b_zero),
2007 rewriter.create<TF::LessOp>(loc, b_false.getType(), n_plus_y, b_rows));
2008 Value in_bounds = rewriter.create<ReshapeOp>(
2009 loc,
2010 RankedTensorType::get(Shape({num_diags, max_diag_len}),
2011 rewriter.getIntegerType(1)),
2012 rewriter.create<AndOp>(loc, x_in_bounds, y_in_bounds));
2013
2014 // Now combine x and y into the index data structure needed for gather.
2015 Shape concat_shape({2, num_diags, max_diag_len});
2016 Value start_indices = rewriter.create<ConcatenateOp>(
2017 loc, RankedTensorType::get(concat_shape, rewriter.getIntegerType(32)),
2018 mlir::ValueRange({n_plus_y, n_plus_x}),
2019 mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0));
2020
2021 // Shape of the final output. (Except for dimension folding in the
2022 // single diagonal case.)
2023 Shape output_shape;
2024 for (int i = 0; i < num_dims - 2; i++) {
2025 output_shape.push_back(input_type.getDimSize(i));
2026 }
2027 output_shape.push_back(num_diags);
2028 output_shape.push_back(max_diag_len);
2029
2030 // A slice is the shape of what GatherOp copies per lookup. So the last
2031 // two dimensions (M, N in the matrix-diag-part docs) are where we go
2032 // through entry by entry.
2033 ArrayRef<int64_t> input_shape = input_type.getShape();
2034 Shape slice_sizes(input_shape.begin(), input_shape.end());
2035 int slice_dimensions = slice_sizes.size();
2036 slice_sizes[slice_dimensions - 2] = 1;
2037 slice_sizes[slice_dimensions - 1] = 1;
2038
2039 // Dimensions of the input we won't see in the output (M and N).
2040 SmallVector<int64_t, 2> collapsed_dims(
2041 {slice_dimensions - 2, slice_dimensions - 1});
2042
2043 // Which dimensions (in the input) the two offset "columns" map to.
2044 SmallVector<int64_t, 2> start_index_map({num_dims - 2, num_dims - 1});
2045
2046 // Gather the diagonal entries.
2047 // TODO(kramm): For a single diagonal, this might be slower than the
2048 // mask + sum approach. Special-case num_diags==1?
2049 auto dims_attr = GatherDimensionNumbersAttr::get(
2050 rewriter.getContext(),
2051 /*offset_dims=*/llvm::to_vector<4>(llvm::seq<int64_t>(0, num_dims - 2)),
2052 /*collapsed_slice_dims=*/collapsed_dims, start_index_map,
2053 /*index_vector_dim=*/0);
2054 Value gather = rewriter.create<mhlo::GatherOp>(
2055 loc, op.input(), start_indices, dims_attr,
2056 GetI64ElementsAttr(slice_sizes, &rewriter));
2057
2058 // We now need to broadcast the "in_bounds" boolean expression, as well as
2059 // the padding value, to do the final select.
2060 Shape broadcast_bounds;
2061 for (int i = 0; i < output_shape.size() - 2; i++) {
2062 broadcast_bounds.push_back(output_shape[i]);
2063 }
2064 Value b_in_bounds = rewriter.create<BroadcastOp>(
2065 loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)),
2066 in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter));
2067 Value b_padding = rewriter.create<BroadcastOp>(
2068 loc, op.padding_value(), GetI64ElementsAttr(output_shape, &rewriter));
2069
2070 // Replace all out-of-bounds values in the result with padding_value.
2071 Value result =
2072 rewriter.create<SelectOp>(loc, b_in_bounds, gather, b_padding);
2073
2074 if (num_diags == 1) {
2075 // matrix_diag_part folds away the 1-sized band dimension if we only
2076 // extract a single diagonal.
2077 result = rewriter.create<ReshapeOp>(loc, op.getType(), result);
2078 }
2079
2080 rewriter.replaceOp(op, result);
2081 return success();
2082 }
2083 };
2084
2085 // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp
2086 // depending on arity of the op.
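// For example, an equation like "ab,bc->ac" with two inputs becomes
// mhlo.einsum, while a single-input equation (e.g. a reduction such as
// "ab->a") becomes mhlo.unary_einsum; the equation string itself is forwarded
// unchanged. (Illustrative equations only.)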
2087 class ConvertEinsumOp : public OpRewritePattern<TF::EinsumOp> {
2088 public:
2089 using OpRewritePattern::OpRewritePattern;
2090
2091 LogicalResult matchAndRewrite(TF::EinsumOp op,
2092 PatternRewriter &rewriter) const override {
2093 StringAttr equation = op->getAttrOfType<StringAttr>("equation");
2094 if (op.N() == 1) {
2095 rewriter.replaceOpWithNewOp<UnaryEinsumOp>(
2096 op, op.getType(), *op.inputs().begin(), equation);
2097 } else if (op.N() == 2) {
2098 ValueRange inputs = op.inputs();
2099 rewriter.replaceOpWithNewOp<EinsumOp>(op, op.getType(), inputs[0],
2100 inputs[1], equation);
2101 } else {
2102 // TensorFlow EinsumOp verifies that the number of operands is at most
2103 // two.
2104 return failure();
2105 }
2106 return success();
2107 }
2108 };
2109
2110 // Bypasses IdentityN op.
2111 class ConvertIdentityNOp : public OpRewritePattern<TF::IdentityNOp> {
2112 public:
2113 using OpRewritePattern<TF::IdentityNOp>::OpRewritePattern;
2114 LogicalResult matchAndRewrite(TF::IdentityNOp op,
2115 PatternRewriter &rewriter) const override {
2116 rewriter.replaceOp(op, op.getOperands());
2117 return success();
2118 }
2119 };
2120
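// Converts TF::RFFTOp and TF::IRFFTOp to mhlo.fft. The innermost input
// dimension is sliced or zero-padded to the expected size (fft_length for
// RFFT, fft_length / 2 + 1 for IRFFT) before the FFT op is emitted; both
// fft_length and the innermost input dimension must be static.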
2121 template <typename OpTy>
2122 class ConvertFFTOp : public OpRewritePattern<OpTy> {
2123 public:
2124 using OpRewritePattern<OpTy>::OpRewritePattern;
2125 LogicalResult matchAndRewrite(OpTy op,
2126 PatternRewriter &rewriter) const override {
2127 auto input_ty = op.input().getType().template cast<ShapedType>();
2128 if (!input_ty.hasRank()) {
2129 return failure();
2130 }
2131 auto input_shape = input_ty.getShape();
2132 DenseIntElementsAttr fft_length_attr;
2133 if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) {
2134 return failure();
2135 }
2136 int64_t fft_length;
2137 if (fft_length_attr.getNumElements() != 0) {
2138 fft_length = fft_length_attr.getValues<IntegerAttr>()[0].getInt();
2139 } else {
2140 return failure();
2141 }
2142
2143 int64_t expected_dim = fft_length;
2144 std::string fft_string = "RFFT";
2145 if (typeid(OpTy) == typeid(TF::IRFFTOp)) {
2146 expected_dim = fft_length / 2 + 1;
2147 fft_string = "IRFFT";
2148 }
2149 Location loc = op.getLoc();
2150
2151 // The inner-most dim cannot be dynamic.
2152 if (input_ty.isDynamicDim(input_shape.size() - 1)) {
2153 return failure();
2154 }
2155
2156 auto expected_shape = llvm::to_vector<4>(input_shape.drop_back());
2157 expected_shape.push_back(expected_dim);
2158
2159 // Zero pad or truncate the last axis
2160 Value reshaped = op.input();
2161 SmallVector<int64_t, 4> begin_indices(input_shape.size(), 0);
2162 SmallVector<int64_t, 4> strides(input_shape.size(), 1);
2163
2164 // Last dim larger than expected_dim, slice the input
2165 if (input_shape.back() > expected_dim) {
2166 reshaped = rewriter.create<SliceOp>(
2167 op.getLoc(),
2168 RankedTensorType::get(expected_shape, input_ty.getElementType()),
2169 op.input(), GetI64ElementsAttr(begin_indices, &rewriter),
2170 GetI64ElementsAttr(expected_shape, &rewriter),
2171 GetI64ElementsAttr(strides, &rewriter));
2172
2173 // Last dim smaller than expected_dim, zero-pad the input
2174 } else if (input_ty.getShape().back() < expected_dim) {
2175 SmallVector<int64_t, 4> no_padding(input_shape.size(), 0);
2176 SmallVector<int64_t, 4> padding(input_shape.size() - 1, 0);
2177 padding.push_back(expected_dim - input_shape.back());
2178 Value zero =
2179 GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter);
2180 reshaped = rewriter.create<PadOp>(
2181 loc, RankedTensorType::get(expected_shape, input_ty.getElementType()),
2182 op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter),
2183 GetI64ElementsAttr(padding, &rewriter),
2184 GetI64ElementsAttr(no_padding, &rewriter));
2185 }
2186
2187 rewriter.replaceOpWithNewOp<FftOp>(
2188 op, op.getType(), reshaped,
2189 FftTypeAttr::get(rewriter.getContext(),
2190 symbolizeFftType(fft_string).getValue()),
2191 rewriter.getI64TensorAttr(fft_length));
2192 return success();
2193 }
2194 };
2195
2196 using ConvertRFFTOp = ConvertFFTOp<TF::RFFTOp>;
2197 using ConvertIRFFTOp = ConvertFFTOp<TF::IRFFTOp>;
2198
2199 // The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO
2200 // BatchNormGradOp for training and a sequence of binary ops for inference.
2201 // TODO(b/145536565): move to legalize_tf_patterns.td if it applies.
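// For the inference path, the sequence of binary ops below implements
// (a summary of the formulas, with element types first converted to the
// statistics type):
//   scratch1 = rsqrt(variance + epsilon)
//   x_backprop = y_backprop * (scale * scratch1)
//   scale_backprop = scratch1 * sum(y_backprop * (x - mean))
//   offset_backprop = sum(y_backprop)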
2202 template <typename FusedBatchNormGradOpT>
2203 class ConvertFusedBatchNormGradBase
2204 : public OpRewritePattern<FusedBatchNormGradOpT> {
2205 public:
2206 using OpRewritePattern<FusedBatchNormGradOpT>::OpRewritePattern;
2207
2208 LogicalResult matchAndRewrite(FusedBatchNormGradOpT op,
2209 PatternRewriter &rewriter) const override {
2210 Location loc = op.getLoc();
2211 Value grad = op.y_backprop();
2212 Value act = op.x();
2213 Value scale = op.scale();
2214 Value mean = op.reserve_space_1();
2215 Value var = op.reserve_space_2();
2216
2217 // TODO(b/141785544): Update this to not require static shapes.
2218 // activation shape needs to be static to convert negative indices in
2219 // TensorFlow to absolute indices required by HLO.
2220 RankedTensorType act_type =
2221 act.getType().template dyn_cast<RankedTensorType>();
2222 if (!act_type) return failure();
2223 Type act_ele_type = act_type.getElementType();
2224 // To support mixed precision, the statistics type, which may be more
2225 // precise than the input types, is used for this op.
2226 Type kernel_type =
2227 scale.getType().template cast<TensorType>().getElementType();
2228 grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
2229 act = rewriter.create<ConvertOp>(loc, act, kernel_type);
2230
2231 tensorflow::TensorFormat data_format;
2232 if (!FormatFromString(op.data_format().str(), &data_format))
2233 return op.emitOpError("invalid data format");
2234
2235 auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act);
2236 auto feature_dim = feature_dim_attr.getValue().getSExtValue();
2237
2238 // Gets the result values.
2239 Value x_backprop, scale_backprop, offset_backprop;
2240 if (op.is_training()) { // training
2241 // TODO(b/145536565): handle GPU logic separately.
2242 // Infers the output type with the converted `act`.
2243 Type feature_type = RankedTensorType::get(
2244 {GetDimSize(act_type, feature_dim)}, kernel_type);
2245
2246 SmallVector<Type, 3> operand_types = {act.getType(), feature_type,
2247 feature_type};
2248 auto training_op = rewriter.create<BatchNormGradOp>(
2249 loc, operand_types, act, scale, mean, var, grad, op.epsilon(),
2250 feature_dim);
2251
2252 x_backprop = training_op.getResult(0);
2253
2254 scale_backprop = training_op.getResult(1);
2255
2256 offset_backprop = training_op.getResult(2);
2257 } else { // inference
2258 SmallVector<int64_t, 4> non_feature_dims;
2259 for (int64_t i = 0; i < act_type.getRank(); ++i) {
2260 if (i == feature_dim) continue;
2261 non_feature_dims.push_back(i);
2262 }
2263 auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter);
2264 auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2265
2266 // scratch1 = rsqrt(var + epsilon)
2267 RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
2268 auto epsilon = rewriter.create<ConstantOp>(
2269 loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
2270 auto add_op = rewriter.create<chlo::BroadcastAddOp>(
2271 loc, var, epsilon.getResult(), scalar_broadcast_dims);
2272
2273 Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
2274
2275 // scratch2 = sum(y_backprop * (x - mean))
2276 auto sub_op = rewriter.create<mhlo::SubtractOp>(
2277 loc, act,
2278 Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
2279 auto weighted_grad = rewriter.create<mhlo::MulOp>(loc, grad, sub_op);
2280 Value scratch2 =
2281 ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
2282
2283 // x_backprop = y_backprop * (scale * scratch1)
2284 auto scaled_grad =
2285 rewriter.create<mhlo::MulOp>(loc, op.scale(), scratch1);
2286 x_backprop = rewriter.create<mhlo::MulOp>(
2287 loc, grad,
2288 Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
2289 rewriter));
2290
2291 // scale_backprop = scratch2 * scratch1
2292 scale_backprop = rewriter.create<mhlo::MulOp>(loc, scratch1, scratch2);
2293
2294 // offset_backprop = sum(y_backprop)
2295 offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
2296 }
2297
2298 x_backprop = rewriter.create<ConvertOp>(loc, x_backprop, act_ele_type);
2299 Value last_val[2];
2300 if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) {
2301 // It doesn't matter what values we provide for the last 2 results.
2302 last_val[0] = last_val[1] = op.x();
2303 } else {
2304 auto const_val = rewriter.create<ConstantOp>(
2305 op.getLoc(),
2306 DenseElementsAttr::get<float>(
2307 RankedTensorType::get({0}, getElementTypeOrSelf(op.getResult(3))),
2308 0.0));
2309 auto maybe_cast = [&](Value val, Type t) -> Value {
2310 if (val.getType() == t) return val;
2311 return rewriter.create<tensor::CastOp>(op.getLoc(), t, val);
2312 };
2313 last_val[0] = maybe_cast(const_val, op.getResult(3).getType());
2314 last_val[1] = maybe_cast(const_val, op.getResult(4).getType());
2315 }
2316 rewriter.replaceOp(
2317 op, {/*x_backprop=*/x_backprop,
2318 /*scale_backprop=*/scale_backprop,
2319 /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]});
2320 return success();
2321 }
2322 };
2323
2324 using ConvertFusedBatchNormGradOp =
2325 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradOp>;
2326 using ConvertFusedBatchNormGradV2Op =
2327 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV2Op>;
2328 using ConvertFusedBatchNormGradV3Op =
2329 ConvertFusedBatchNormGradBase<TF::FusedBatchNormGradV3Op>;
2330
2331 // Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or
2332 // HLO BatchNormInferenceOp, depending on the value of the 'is_training'
2333 // parameter.
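// In the training case, the batch variance reported back to TF is rescaled by
// Bessel's correction factor N / (N - 1), where N = num_elements(x) /
// num_elements(scale). When exponential_avg_factor != 1.0, the running
// mean/variance outputs are blended as
//   new = (1 - exponential_avg_factor) * old + exponential_avg_factor * batch
// (a summary of the code below).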
2334 template <typename FusedBatchNormOpT>
2335 class ConvertFusedBatchNormBase : public OpRewritePattern<FusedBatchNormOpT> {
2336 public:
2337 using OpRewritePattern<FusedBatchNormOpT>::OpRewritePattern;
2338
2339 LogicalResult matchAndRewrite(FusedBatchNormOpT op,
2340 PatternRewriter &rewriter) const override {
2341 tensorflow::TensorFormat data_format;
2342 if (!FormatFromString(op.data_format().str(), &data_format))
2343 return op.emitOpError("invalid data format");
2344
2345 auto feature_dim = getFeatureDimensionAttr(rewriter, data_format, op.x());
2346
2347 auto input_type_tensor = op.x().getType().template cast<TensorType>();
2348 auto input_element_type = input_type_tensor.getElementType();
2349
2350 auto scale_type_tensor = op.scale().getType().template cast<TensorType>();
2351 auto scale_element_type = scale_type_tensor.getElementType();
2352
2353 auto mean_type_tensor = op.mean().getType().template cast<TensorType>();
2354 auto mean_element_type = mean_type_tensor.getElementType();
2355 // In the training case, dimensions of input tensors must be static.
2356 if (op.is_training() && (!input_type_tensor.hasStaticShape() ||
2357 !scale_type_tensor.hasStaticShape() ||
2358 !mean_type_tensor.hasStaticShape()))
2359 return failure();
2360
2361 // TODO(b/69928690): Support mixed precision in the XLA batch
2362 // normalization operators. As a workaround, create a new x with the same
2363 // element type as scale (which may be more precise than the input type).
2364 Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
2365 scale_element_type);
2366 TensorType bn_train_input_type_tensor =
2367 bn_train_input.getType().template cast<TensorType>();
2368
2369 if (op.is_training()) {
2370 // Training case.
2371 auto operand_shape = bn_train_input_type_tensor.getShape();
2372 // The mean and variance are each 1 dimensional arrays the size of the
2373 // feature dimension, with the same element type as the operand (x).
2374 // This shape must be constructed manually because the mean and variance
2375 // inputs are empty in the training case.
2376 Type mean_var_type = RankedTensorType::get(
2377 {operand_shape[feature_dim.getInt()]}, scale_element_type);
2378 // The op result type is a tuple of 3 values: the output with the same shape
2379 // as the input, batch_mean, and batch_var.
2380 SmallVector<Type, 3> operand_types = {bn_train_input_type_tensor,
2381 mean_var_type, mean_var_type};
2382 auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
2383 op.getLoc(), operand_types, bn_train_input, op.scale(), op.offset(),
2384 op.epsilon(), feature_dim.getInt());
2385 // HLO op outputs a tuple of tensors. Extract those results.
2386 Value y_out = bn_train_op.getResult(0);
2387 Value batch_mean = bn_train_op.getResult(1);
2388 Value reserve_space_1 = batch_mean;
2389 Value batch_variance = bn_train_op.getResult(2);
2390
2391 // Apply Bessel's correction on the variance.
2392 int total_input_size = bn_train_input_type_tensor.getNumElements();
2393 int total_scale_size = scale_type_tensor.getNumElements();
2394 int sample_size = total_input_size / total_scale_size;
2395 int sample_size_minus_one = std::max(1, sample_size - 1);
2396 double factor = static_cast<double>(sample_size) /
2397 static_cast<double>(sample_size_minus_one);
2398 auto factor_const_op = rewriter.create<mhlo::ConstantOp>(
2399 op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
2400
2401 Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
2402 op.getLoc(), batch_variance.getType(), batch_variance,
2403 factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
2404
2405 // Convert back to input type to stay aligned with expected output type
2406 // for TF op.
2407 y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), y_out,
2408 input_element_type);
2409
2410 float exponential_avg_factor =
2411 op.exponential_avg_factor().convertToFloat();
2412 if (exponential_avg_factor != 1.0f) {
2413 auto alpha = rewriter.create<mhlo::ConstantOp>(
2414 op.getLoc(), rewriter.getFloatAttr(mean_element_type,
2415 1.0f - exponential_avg_factor));
2416 auto beta = rewriter.create<mhlo::ConstantOp>(
2417 op.getLoc(),
2418 rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
2419
2420 // new_running_mean = alpha * old_mean + beta * batch_mean.
2421 auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
2422 op.getLoc(), op.mean().getType(), alpha, op.mean(),
2423 /*broadcast_dimensions=*/DenseIntElementsAttr());
2424 auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
2425 op.getLoc(), batch_mean.getType(), beta, batch_mean,
2426 /*broadcast_dimensions=*/DenseIntElementsAttr());
2427 batch_mean = rewriter.create<chlo::BroadcastAddOp>(
2428 op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
2429 /*broadcast_dimensions=*/DenseIntElementsAttr());
2430
2431 // new_running_variance = alpha * old_variance + beta * batch_variance.
2432 auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
2433 op.getLoc(), op.variance().getType(), alpha, op.variance(),
2434 /*broadcast_dimensions=*/DenseIntElementsAttr());
2435 auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
2436 op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
2437 /*broadcast_dimensions=*/DenseIntElementsAttr());
2438 corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
2439 op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
2440 /*broadcast_dimensions=*/DenseIntElementsAttr());
2441 }
2442
2443 if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2444 // FusedBatchNormV2 expects 4 outputs.
2445 // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2".
2446 // They are used to pass the per-batch mean and variance to the
2447 // gradient. Here we maintain the same behavior by setting them to the
2448 // mean and variance calculated by BatchNormTraining.
2449 rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2450 /*batch_variance=*/corrected_variance,
2451 /*reserve_space_1=*/reserve_space_1,
2452 /*reserve_space_2=*/batch_variance});
2453 } else { // TF::FusedBatchNormV3Op
2454 // For FusedBatchNormV3Op, also create a constant tensor to forward to
2455 // last reserve_space_3 output.
2456 auto reserve_space_3_type =
2457 op.getResult(5).getType().template cast<TensorType>();
2458 int num_elements = reserve_space_3_type.hasStaticShape()
2459 ? reserve_space_3_type.getNumElements()
2460 : 0;
2461 auto const_attr_type = RankedTensorType::get(
2462 {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2463 Value dummy_const = rewriter.create<ConstantOp>(
2464 op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2465 if (const_attr_type != reserve_space_3_type)
2466 dummy_const = rewriter.create<tensor::CastOp>(
2467 op.getLoc(), reserve_space_3_type, dummy_const);
2468 rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean,
2469 /*batch_variance=*/corrected_variance,
2470 /*reserve_space_1=*/reserve_space_1,
2471 /*reserve_space_2=*/batch_variance,
2472 /*reserve_space_3=*/dummy_const});
2473 }
2474 } else { // Inference case.
2475 auto bn_train_op = rewriter.create<BatchNormInferenceOp>(
2476 op.getLoc(),
2477 /*result_type=*/bn_train_input_type_tensor, bn_train_input,
2478 op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(),
2479 feature_dim.getInt());
2480
2481 // Convert back to input type to stay aligned with expected output type
2482 // for TF op.
2483 auto y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), bn_train_op,
2484 input_element_type);
2485
2486 // The mean, variance, and reserved space outputs of the batch norm op are
2487 // not used for inference. It doesn't matter what values we provide for
2488 // the remaining results as long as they are of the right type. Forward the
2489 // input mean and variance to the output mean, variance, reserve_space_1 and
2490 // reserve_space_2.
2491 if (std::is_same<FusedBatchNormOpT, TF::FusedBatchNormV2Op>::value) {
2492 rewriter.replaceOp(op, {/*y=*/y_out,
2493 /*batch_mean=*/op.mean(),
2494 /*batch_variance=*/op.variance(),
2495 /*reserve_space_1=*/op.mean(),
2496 /*reserve_space_2=*/op.variance()});
2497 } else {
2498 // For FusedBatchNormV3Op, also create a constant tensor to forward to the
2499 // last reserve_space_3 output.
2500 auto reserve_space_3_type =
2501 op.getResult(5).getType().template cast<TensorType>();
2502 int num_elements = reserve_space_3_type.hasStaticShape()
2503 ? reserve_space_3_type.getNumElements()
2504 : 0;
2505 auto const_attr_type = RankedTensorType::get(
2506 {num_elements}, getElementTypeOrSelf(reserve_space_3_type));
2507 Value dummy_const = rewriter.create<ConstantOp>(
2508 op.getLoc(), DenseElementsAttr::get<float>(const_attr_type, 0.0));
2509 if (const_attr_type != reserve_space_3_type)
2510 dummy_const = rewriter.create<tensor::CastOp>(
2511 op.getLoc(), reserve_space_3_type, dummy_const);
2512 rewriter.replaceOp(op, {/*y=*/y_out,
2513 /*batch_mean=*/op.mean(),
2514 /*batch_variance=*/op.variance(),
2515 /*reserve_space_1=*/op.mean(),
2516 /*reserve_space_2=*/op.variance(),
2517 /*reserve_space_3=*/dummy_const});
2518 }
2519 }
2520 return success();
2521 }
2522 };
2523
2524 using ConvertFusedBatchNormV2Op =
2525 ConvertFusedBatchNormBase<TF::FusedBatchNormV2Op>;
2526 using ConvertFusedBatchNormV3Op =
2527 ConvertFusedBatchNormBase<TF::FusedBatchNormV3Op>;
2528
2529 using PaddingArray = std::vector<std::pair<int64_t, int64_t>>;
2530
2531 // Returns padding values for ReduceWindow op as a vector of pairs.
2532 //
2533 // Requires padding to be either 'SAME' or 'VALID' and the number of input
2534 // dimensions to match the number of window dimensions and window strides.
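// For example (an illustrative sketch, not taken from a test): with 'SAME'
// padding, input_dims = [4, 6], window dimensions [1, 3] and window strides
// [1, 2], xla::MakePadding yields {{0, 0}, {0, 1}}, i.e. one element of
// padding on the high side of the second dimension.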
2535 template <int num_dims>
2536 static PaddingArray GetReduceWindowPaddingAsArray(
2537 llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2538 ArrayAttr window_strides, StringRef padding, Builder *builder) {
2539 if (padding == "VALID") {
2540 return PaddingArray(num_dims, std::make_pair(0, 0));
2541 }
2542 assert(padding == "SAME");
2543 llvm::SmallVector<int64_t, num_dims> input_shape, window_shape, strides;
2544 input_shape.reserve(input_dims.size());
2545 window_shape.reserve(window_dims.size());
2546 strides.reserve(window_strides.size());
2547
2548 for (const auto &dim : input_dims) input_shape.push_back(dim);
2549 for (Attribute attr : window_dims)
2550 window_shape.push_back(attr.cast<IntegerAttr>().getInt());
2551 for (Attribute attr : window_strides)
2552 strides.push_back(attr.cast<IntegerAttr>().getInt());
2553
2554 PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides,
2555 ::xla::Padding::kSame);
2556 return paddings;
2557 }
2558
2559 // Same as GetReduceWindowPaddingAsArray but returns padding as
2560 // DenseIntElementsAttr. Returns empty attribute for `VALID` padding.
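// For example (illustrative only), a rank-2 'SAME' padding of
// {{0, 0}, {0, 1}} is returned as dense<[[0, 0], [0, 1]]> : tensor<2x2xi64>.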
2561 template <int num_dims>
2562 static DenseIntElementsAttr GetReduceWindowPaddingAsAttr(
2563 llvm::ArrayRef<int64_t> input_dims, ArrayAttr window_dims,
2564 ArrayAttr window_strides, StringRef padding, Builder *builder) {
2565 if (padding == "VALID") return {};
2566 assert(padding == "SAME");
2567 PaddingArray paddings = GetReduceWindowPaddingAsArray<num_dims>(
2568 input_dims, window_dims, window_strides, padding, builder);
2569 int64_t rank = paddings.size();
2570 llvm::SmallVector<int64_t, num_dims * 2> flatten_paddings(rank * 2);
2571 for (int i = 0; i < rank; i++) {
2572 flatten_paddings[2 * i] = paddings[i].first;
2573 flatten_paddings[2 * i + 1] = paddings[i].second;
2574 }
2575 return DenseIntElementsAttr::get(
2576 RankedTensorType::get({rank, 2}, builder->getIntegerType(64)),
2577 flatten_paddings);
2578 }
2579
2580 // Helper function for dividing each entry of `pooled` by the count of its
2581 // corresponding window, i.e., the number of non-padding entries of the window
2582 // which an `AvgPool` operation performed on an `input_shape`-tensor would map
2583 // to this entry, depending on `ksize` and `strides`. This function is used for
2584 // `AvgPool` and `AvgPoolGrad` legalizations.
2585 // `zero` is passed as a parameter because it can be reused from caller level.
2586 // `pooled` must have `RankedTensorType`.
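// As a rough example (illustrative values only): for a 1-D input of length 4
// with ksize 3, stride 2 and 'SAME' padding, the two windows contain 3 and 2
// non-padding entries respectively, so `pooled` is divided elementwise by
// [3, 2].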
2587 template <typename OpTy, int num_dims>
2588 Operation *AvgPoolDivideByCount(
2589 Value pooled, const SmallVector<int64_t, num_dims> &input_shape,
2590 const SmallVector<int64_t, num_dims> &ksize,
2591 const SmallVector<int64_t, num_dims> &strides, OpTy op, Value zero,
2592 PatternRewriter &rewriter) {
2593 Location loc = op.getLoc();
2594 RankedTensorType pooled_type =
2595 pooled.getType().template cast<RankedTensorType>();
2596 Type element_type = pooled_type.getElementType();
2597 Operation *result = nullptr;
2598 RankedTensorType orig_input_type =
2599 RankedTensorType::get(input_shape, element_type);
2600
2601 if (op.padding() == "VALID") {
2602 // All window counts are equal here because we don't have padding
2603 // (each entry of `pooled` corresponds to a window that consists of
2604 // original input entries only).
2605 int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1,
2606 std::multiplies<int64_t>());
2607 // Divide `pooled` by window counts.
2608 Value divisor =
2609 GetScalarConstOfType(element_type, loc, window_count, &rewriter);
2610 auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
2611 result = rewriter.create<chlo::BroadcastDivOp>(
2612 loc, pooled_type, pooled, divisor, scalar_broadcast_dims);
2613 } else {
2614 assert(op.padding() == "SAME");
2615 // For SAME padding, only original entries that contributed to a window
2616 // are counted for the average of this window, not padded entries.
2617
2618 // Build all-ones tensor of same shape as the original input.
2619 ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1);
2620 auto all_ones_tensor = rewriter.create<ConstantOp>(loc, splat);
2621
2622 // Get padding for the input.
2623 DenseIntElementsAttr input_padding_attr =
2624 GetReduceWindowPaddingAsAttr<num_dims>(
2625 input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2626
2627 // Count the 1's in each window, using the same padding as for the input,
2628 // which gives us the window counts by which `pooled` needs to be divided.
2629 auto divisor = rewriter.create<ReduceWindowOp>(
2630 loc, pooled_type,
2631 /*operand=*/all_ones_tensor,
2632 /*init_value=*/zero,
2633 /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2634 /*window_strides=*/GetI64ElementsAttr(op.strides()),
2635 /*base_dilations=*/DenseIntElementsAttr(),
2636 /*window_dilations=*/DenseIntElementsAttr(),
2637 /*padding=*/input_padding_attr);
2638 BuildReduceBody<AddOp>(element_type, &divisor.body(), &rewriter);
2639
2640 // Divide `pooled` by window counts.
2641 result = rewriter.create<mhlo::DivOp>(loc, pooled_type, pooled,
2642 divisor.getResult(0));
2643 }
2644 return result;
2645 }
2646
2647 Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
2648 Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
2649
2650 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
2651 // dimensions with add as the reduction function. The reduction result is
2652 // then divided by the number of elements in the window.
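// A rough sketch of the generated IR (types, attribute values, and divisor
// handling are illustrative only; see AvgPoolDivideByCount for the 'SAME'
// padding case):
//
//   %init = mhlo.constant dense<0.0> : tensor<f32>
//   %sum = "mhlo.reduce_window"(%input, %init) ({ ... add reduction body ... })
//       {window_dimensions = ..., window_strides = ..., padding = ...}
//   %avg = chlo.broadcast_divide %sum, %window_count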
2653 template <typename OpTy, int num_dims>
2654 class ConvertAvgPoolOp : public OpRewritePattern<OpTy> {
2655 public:
2656 using OpRewritePattern<OpTy>::OpRewritePattern;
2657
2658 LogicalResult matchAndRewrite(OpTy op,
2659 PatternRewriter &rewriter) const override {
2660 Value input_value = GetAvgPoolInput(op);
2661 auto input_type =
2662 input_value.getType().template dyn_cast<RankedTensorType>();
2663 if (!input_type) return failure();
2664
2665 // We will do accumulation first; use a larger bitwidth if suitable.
2666 Type input_element_type = input_type.getElementType();
2667 Type sum_element_type = GetSumAccumulationType(input_element_type);
2668 Type result_type;
2669
2670 // The result type for reduction and division with the proper element type.
2671 if (auto ranked_type = op.getType().template dyn_cast<RankedTensorType>())
2672 result_type =
2673 RankedTensorType::get(ranked_type.getShape(), sum_element_type);
2674 else
2675 result_type = UnrankedTensorType::get(sum_element_type);
2676
2677 // Convert the input if we need to enlarge the element type's bitwidth.
2678 if (input_element_type != sum_element_type)
2679 input_value = rewriter.create<ConvertOp>(op.getLoc(), input_value,
2680 sum_element_type);
2681
2682 // Create the ReduceWindow op.
2683 Value init =
2684 GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
2685 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2686 input_type.getShape(), op.ksize(), op.strides(), op.padding(),
2687 &rewriter);
2688 auto reduce = rewriter.create<ReduceWindowOp>(
2689 op.getLoc(), result_type, input_value, init,
2690 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
2691 /*base_dilations=*/DenseIntElementsAttr(),
2692 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2693 BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
2694
2695 // Count the number of elements in the window. The following calculation
2696 // is only valid when there is no padding.
2697 SmallVector<int64_t, num_dims> input_shape(
2698 llvm::to_vector<num_dims>(input_type.getShape()));
2699 SmallVector<int64_t, num_dims> ksize, strides;
2700 GetI64ArrayAttrValues(op.ksize(), &ksize);
2701 GetI64ArrayAttrValues(op.strides(), &strides);
2702
2703 Operation *result_op = AvgPoolDivideByCount<OpTy, num_dims>(
2704 reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter);
2705
2706 // Convert back if we enlarged the element type's bitwidth.
2707 Value result = result_op->getOpResult(0);
2708 if (input_element_type != sum_element_type)
2709 result =
2710 rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
2711
2712 rewriter.replaceOp(op, result);
2713 return success();
2714 }
2715 };
2716
2717 using ConvertAvgPool2DOp = ConvertAvgPoolOp<TF::AvgPoolOp, /*num_dims=*/4>;
2718 using ConvertAvgPool3DOp = ConvertAvgPoolOp<TF::AvgPool3DOp, /*num_dims=*/5>;
2719
2720 // `AvgPoolGradOp` is converted to the following operations:
2721 // 1. Divide each entry of the output gradient (the gradient for the previous
2722 // layer in backpropagation order) by the count of the corresponding window
2723 // (i.e., the number of non-padding entries of the window which `AvgPool`
2724 // has mapped to this entry in forward propagation).
2725 // 2. Add appropriate interior and exterior padding for step 3 (see example
2726 // below).
2727 // 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape
2728 // as windows) and stride 1 in each dimension. This is implemented as a
2729 // `ReduceWindowOp` with `AddOp` as body.
2730 //
2731 // Example:
2732 // Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2,
2733 // and SAME padding with 0's. It is defined by
2734 // f(x) = [ (x_1 + x_2 + x_3) / 3 ] ( x = (x_1, x_2, x_3, x_4) )
2735 // [ (x_3 + x_4 + 0) / 2 ] (the 0 results from right padding)
2736 // Note that for SAME padding in `AvgPool` the padded entries are not counted
2737 // for the average; this is why the second denominator is 2 and not 3.
2738 // The Jacobian Df is
2739 // [ 1/3 1/3 1/3 0 ]
2740 // [ 0 0 1/2 1/2 ]
2741 //
2742 // Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only
2743 // needs the original input shape and not the tensor as argument).
2744 // Let v = [ 4 6 ]^T be the output gradient (^T = transposed). Then the
2745 // average pool gradient is given by
2746 // Df^T * v = [ 4/3 4/3 13/3 3 ]^T
2747 // Instead of a matrix-vector-multiplication we can utilize the sparsity and
2748 // structure of Df by using the 3-step approach from above:
2749 // 1. Divide output gradient v by window counts: [ 4/3 6/2 ]^T
2750 // 2. Add appropriate padding: [ 0 0 4/3 0 3 0 ]^T
2751 // 3. Convolve with kernel [ 1 1 1 ]: [ 4/3 4/3 13/3 3 ]^T
2752 //
2753 // Note that the padding in step 2. is chosen in such a way that the subsequent
2754 // convolution produces the gradient. Higher dimensions, different padding, and
2755 // different windows/strides work in a similar way, the main difference is in
2756 // the computation of the paddings in step 2.
2757 //
2758 // For more details on backpropagation for convolution of which `AvgPoolGrad`
2759 // is a special case see `tensorflow/core/kernels/conv_grad_ops.h`.
2760 // `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more
2761 // examples for different cases.
2762 template <typename OpTy, int num_dims>
2763 class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
2764 using DimVector = SmallVector<int64_t, num_dims>;
2765
2766 public:
2767 using OpRewritePattern<OpTy>::OpRewritePattern;
2768
2769 LogicalResult matchAndRewrite(OpTy op,
2770 PatternRewriter &rewriter) const override {
2771 Location loc = op.getLoc();
2772 tensorflow::TensorFormat data_format;
2773 if (!FormatFromString(op.data_format().str(), &data_format)) {
2774 return op.emitOpError("invalid data format");
2775 }
2776 // `out_grad` is the gradient that was propagated via backpropagation from
2777 // the output layer.
2778 Value out_grad = op.grad();
2779 auto out_grad_type =
2780 out_grad.getType().template dyn_cast<RankedTensorType>();
2781 if (!out_grad_type) {
2782 return failure();
2783 }
2784 Type element_type = out_grad_type.getElementType();
2785 DenseIntElementsAttr orig_input_shape_attr;
2786 if (!matchPattern(op.orig_input_shape(),
2787 m_Constant(&orig_input_shape_attr))) {
2788 return failure();
2789 }
2790 auto orig_input_shape_values = orig_input_shape_attr.getValues<int32_t>();
2791 DimVector orig_input_shape(orig_input_shape_values.begin(),
2792 orig_input_shape_values.end());
2793 DimVector ksize, strides;
2794 GetI64ArrayAttrValues(op.ksize(), &ksize);
2795 GetI64ArrayAttrValues(op.strides(), &strides);
2796 Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter);
2797
2798 auto out_grad_divided = AvgPoolDivideByCount<OpTy, num_dims>(
2799 out_grad, orig_input_shape, ksize, strides, op, zero, rewriter);
2800
2801 // Get same padding as for original input.
2802 PaddingArray orig_padding = GetReduceWindowPaddingAsArray<num_dims>(
2803 orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter);
2804
2805 // Add padding around `out_grad_divided` values in such a way that the
2806 // subsequent `ReduceWindowOp` produces the gradient.
2807 DimVector out_grad_shape(
2808 llvm::to_vector<num_dims>(out_grad_type.getShape()));
2809 DimVector low_padding(num_dims, 0);
2810 DimVector high_padding(num_dims, 0);
2811 DimVector interior_padding(num_dims, 0);
2812 constexpr int num_spatial_dims = num_dims - 2;
2813 for (int i = 0; i < num_spatial_dims; ++i) {
2814 int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
2815 int orig_input_shape_padded_in_dim = orig_input_shape[dim] +
2816 orig_padding[dim].first +
2817 orig_padding[dim].second;
2818 // Set interior padding such that neighboring entries from
2819 // `out_grad_divided` have distance `strides[dim]` from each other in
2820 // every dimension.
2821 interior_padding[dim] = strides[dim] - 1;
2822 // Set exterior padding in the same way as for convolution gradient
2823 // computation.
2824 auto status = ::xla::ConvGradExtractAndVerifyDimension(
2825 /*input_size=*/orig_input_shape_padded_in_dim,
2826 /*filter_size=*/ksize[dim],
2827 /*output_size=*/out_grad_shape[dim],
2828 /*dilation=*/1,
2829 /*stride=*/strides[dim],
2830 /*padding=*/::xla::Padding::kValid);
2831 if (!status.ok()) {
2832 return failure();
2833 }
2834 ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim =
2835 status.ValueOrDie();
2836 // Subtract the original exterior padding since it doesn't contribute to
2837 // the gradient. Note that we save one `PadOp` and some unnecessary kernel
2838 // computations, compared to the `xla::AvgPoolGrad` implementation, by
2839 // subtracting the original exterior padding before `ReduceWindowOp`
2840 // instead of trimming the result of `ReduceWindowOp` (the final result is
2841 // the same because all strides are 1).
2842 low_padding[dim] =
2843 conv_grad_spatial_dim.pad_before - orig_padding[dim].first;
2844 high_padding[dim] =
2845 conv_grad_spatial_dim.pad_after - orig_padding[dim].second;
2846
2847 // Update `out_grad_shape` to result shape of following `PadOp`.
2848 out_grad_shape[dim] = low_padding[dim] + high_padding[dim] +
2849 (out_grad_shape[dim] - 1) * strides[dim] + 1;
2850 }
2851 Value reduce_window_input = rewriter.create<PadOp>(
2852 loc, RankedTensorType::get(out_grad_shape, element_type),
2853 /*operand=*/out_grad_divided->getOpResult(0),
2854 /*padding_value=*/zero,
2855 /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter),
2856 /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter),
2857 /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter));
2858
2859 // Compute result by convolving `reduce_window_input` with an all-ones
2860 // kernel, using `ReduceWindowOp` with `AddOp` body.
2861
2862 Type sum_element_type = GetSumAccumulationType(element_type);
2863 if (element_type != sum_element_type) {
2864 // Convert to appropriate sum accumulation type to avoid precision loss.
2865 reduce_window_input = rewriter.create<ConvertOp>(loc, reduce_window_input,
2866 sum_element_type);
2867 zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter);
2868 }
2869 auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter);
2870 auto reduce_window_op = rewriter.create<ReduceWindowOp>(
2871 loc, RankedTensorType::get(orig_input_shape, sum_element_type),
2872 /*operand=*/reduce_window_input,
2873 /*init_value=*/zero,
2874 /*window_dimensions=*/GetI64ElementsAttr(op.ksize()),
2875 /*window_strides=*/ones,
2876 /*base_dilations=*/DenseIntElementsAttr(),
2877 /*window_dilations=*/DenseIntElementsAttr(),
2878 /*padding=*/DenseIntElementsAttr());
2879 BuildReduceBody<AddOp>(sum_element_type, &reduce_window_op.body(),
2880 &rewriter);
2881 Value result = reduce_window_op.getResult(0);
2882
2883 if (element_type != sum_element_type) {
2884 // Convert back to original element type.
2885 result = rewriter.create<ConvertOp>(op.getLoc(), result, element_type);
2886 }
2887 rewriter.replaceOp(op, {result});
2888 return success();
2889 }
2890 };
2891
2892 using ConvertAvgPool2DGradOp =
2893 ConvertAvgPoolGradOp<TF::AvgPoolGradOp, /*num_dims=*/4>;
2894 using ConvertAvgPool3DGradOp =
2895 ConvertAvgPoolGradOp<TF::AvgPool3DGradOp, /*num_dims=*/5>;
2896
2897 // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window
2898 // dimensions with max as the reduction function.
2899 //
2900 // Sample result for VALID padding mode:
2901 //
2902 // %init = arith.constant dense<...> : tensor<i32>
2903 // %max_pool = "mhlo.reduce_window"(%inp, %init) ["mhlo.maximum"]
2904 // {window_dimensions = ..., window_strides = ... }
2905 //
2906 template <typename OpTy, int num_dims>
2907 class ConvertMaxPoolOp : public OpRewritePattern<OpTy> {
2908 public:
2909 using OpRewritePattern<OpTy>::OpRewritePattern;
2910
2911 LogicalResult matchAndRewrite(OpTy op,
2912 PatternRewriter &rewriter) const override {
2913 Type element_type =
2914 op.input().getType().template cast<TensorType>().getElementType();
2915 if (!element_type.isSignlessIntOrFloat()) return failure();
2916 tensorflow::Padding padding;
2917 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
2918 return failure();
2919 if (padding == tensorflow::Padding::EXPLICIT) {
2920 return failure();
2921 }
2922 Location loc = op.getLoc();
2923 ConstantOp init = GetScalarLimitConstOfType(
2924 element_type, loc, hlo::kInfinityLowest, &rewriter);
2925
2926 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
2927 if (!input_ty) return failure();
2928 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
2929 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
2930 auto reduce = rewriter.create<ReduceWindowOp>(
2931 loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()),
2932 GetI64ElementsAttr(op.strides()),
2933 /*base_dilations=*/DenseIntElementsAttr(),
2934 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
2935 BuildReduceBody<MaxOp>(element_type, &reduce.body(), &rewriter);
2936
2937 rewriter.replaceOp(op, reduce.getResult(0));
2938 return success();
2939 }
2940 };
2941
2942 using ConvertMaxPool2DOp = ConvertMaxPoolOp<TF::MaxPoolOp, /*num_dims=*/4>;
2943 using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;
2944
2945 // Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on
2946 // the condition only.
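// A rough sketch of the generated IR for a vector condition (shapes, names,
// and exact ops are illustrative only):
//
//   %cstr = shape.cstr_eq ... // then/else (and cond) extents must agree
//   %0 = shape.assuming %cstr -> (tensor<2x3xf32>) {
//     %bcast_cond = "mhlo.dynamic_broadcast_in_dim"(%cond, %extents) {...}
//     %sel = "mhlo.select"(%bcast_cond, %t, %e)
//     shape.assuming_yield %sel : tensor<2x3xf32>
//   }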
2947 class ConvertSelectOp : public OpRewritePattern<TF::SelectOp> {
2948 public:
2949 using OpRewritePattern::OpRewritePattern;
2950
2951 LogicalResult matchAndRewrite(TF::SelectOp op,
2952 PatternRewriter &rewriter) const override {
2953 // This lowering only works on ranked types.
2954 auto cond_type = op.condition().getType().dyn_cast<RankedTensorType>();
2955 auto then_type = op.t().getType().dyn_cast<RankedTensorType>();
2956 auto else_type = op.e().getType().dyn_cast<RankedTensorType>();
2957 if (!cond_type || !then_type || !else_type) {
2958 return failure();
2959 }
2960
2961 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
2962 Value cond_shape = b.createOrFold<shape::ShapeOfOp>(op.condition());
2963 Value then_shape = b.createOrFold<shape::ShapeOfOp>(op.t());
2964 Value else_shape = b.createOrFold<shape::ShapeOfOp>(op.e());
2965
2966 // First check that the `then` and `else` shapes are equal.
2967 Value assumption =
2968 b.createOrFold<shape::CstrEqOp>(ValueRange{then_shape, else_shape});
2969 // For a vector cond we also verify that the majormost dim of `then` matches
2970 // the vector size. To do that, split off the first dim of `then`.
2971 bool needs_broadcast = cond_type.getRank() == 1 && then_type.getRank() != 1;
2972 Value then_shape_split = then_shape;
2973 if (needs_broadcast) {
2974 Value const_one = b.create<arith::ConstantIndexOp>(1);
2975 Type extent_first = shape::getExtentTensorType(b.getContext(), 1);
2976 Type extent_second =
2977 shape::getExtentTensorType(b.getContext(), then_type.getRank() - 1);
2978 SmallVector<Value, 2> then_split;
2979 b.createOrFold<shape::SplitAtOp>(then_split,
2980 TypeRange{extent_first, extent_second},
2981 then_shape, const_one);
2982 then_shape_split = then_split[0];
2983 }
2984 // If the condition is not a scalar, check that it matches the other shapes.
2985 if (cond_type.getRank() > 0) {
2986 Value eq_cstr = b.createOrFold<shape::CstrEqOp>(
2987 ValueRange{cond_shape, then_shape_split});
2988 auto witness = shape::WitnessType::get(b.getContext());
2989 assumption = b.createOrFold<shape::AssumingAllOp>(
2990 witness, ValueRange{assumption, eq_cstr});
2991 }
2992 auto result_type = op.getResult().getType().cast<TensorType>();
2993 auto assuming_op =
2994 b.create<shape::AssumingOp>(ArrayRef<Type>{result_type}, assumption);
2995
2996 OpBuilder::InsertionGuard guard(b);
2997 b.createBlock(&assuming_op.getDoRegion());
2998
2999 // Broadcast the cond if necessary.
3000 Value cond = op.condition();
3001 if (needs_broadcast) {
3002 Value result_extents = b.create<shape::ToExtentTensorOp>(
3003 GetExtentsTensorTypeFor(result_type), then_shape);
3004 cond = b.create<mhlo::DynamicBroadcastInDimOp>(
3005 RankedTensorType::get(result_type.getShape(), b.getI1Type()), cond,
3006 result_extents, GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b));
3007 }
3008 Value select = b.create<mhlo::SelectOp>(result_type, cond, op.t(), op.e());
3009 b.create<shape::AssumingYieldOp>(select);
3010 rewriter.replaceOp(op, {assuming_op.getResult(0)});
3011 return success();
3012 }
3013 };
3014
3015 // Converts Sigmoid op to HLO ops computing sigmoid with the following formula:
3016 //
3017 // sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5)
3018 //
3019 // Sample result with 2-d f16 inputs with B batches of N elements each.
3020 //
3021 // // Create an array of 0.5 the shape of the input array.
3022 // %half = mhlo.constant dense<5.000000e-01> : tensor<f32>
3023 // %half_array = "mhlo.broadcast"(half)
3024 // {broadcast_sizes = dense<2> : tensor<1xi64>}
3025 // : (tensor<f32>) -> tensor<2xf32>
3026 //
3027 // // Compute Tanh of half the logit values.
3028 // %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32>
3029 // %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
3030 //
3031 // // Halve the result of Tanh and add 0.5.
3032 // %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32>
3033 // %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32>
3034 //
3035 class ConvertSigmoidOp : public RewritePattern {
3036 public:
3037 explicit ConvertSigmoidOp(MLIRContext *context)
3038 : RewritePattern(
3039 TF::SigmoidOp::getOperationName(), 0, context,
3040 {mhlo::ConstantOp::getOperationName(),
3041 shape::ShapeOfOp::getOperationName(),
3042 shape::ToExtentTensorOp::getOperationName(),
3043 mhlo::DynamicBroadcastInDimOp::getOperationName(),
3044 mhlo::MulOp::getOperationName(), mhlo::TanhOp::getOperationName(),
3045 mhlo::AddOp::getOperationName()}) {}
3046
3047 LogicalResult matchAndRewrite(Operation *sigmoid_op,
3048 PatternRewriter &rewriter) const override {
3049 auto op = cast<TF::SigmoidOp>(sigmoid_op);
3050 Location loc = op.getLoc();
3051
3052 // Create constant half with shape and element type same as the operand.
3053 Value operand = op.getOperand();
3054 auto operand_ty = operand.getType().cast<TensorType>();
3055 auto scalar_ty = RankedTensorType::get({}, operand_ty.getElementType());
3056 ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5);
3057 auto scalar_half = rewriter.create<ConstantOp>(loc, attr);
3058 auto half = BroadcastToShapeOf(loc, scalar_half, operand, rewriter);
3059
3060 auto scaled_input = rewriter.create<MulOp>(loc, operand, half);
3061 auto tanh_op = rewriter.create<TanhOp>(loc, scaled_input);
3062 auto mul_op = rewriter.create<MulOp>(loc, tanh_op, half);
3063 auto add_op = rewriter.create<AddOp>(loc, mul_op, half);
3064
3065 rewriter.replaceOp(op, add_op.getResult());
3066 return success();
3067 }
3068 };
3069
3070 // Converts the tf.Slice op into mhlo.real_dynamic_slice.
3071 // TODO(disc): Recover the static special case's performance with folding and
3072 // canonicalization.
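// A rough sketch of the tail of the lowering (types and names are
// illustrative only):
//
//   %starts  = tensor.from_elements ...      : tensor<2xindex>
//   %ends    = tensor.from_elements ...      : tensor<2xindex>
//   %strides = tensor.from_elements %c1, %c1 : tensor<2xindex>
//   %out = mhlo.real_dynamic_slice %input, %starts, %ends, %strides
//
// where each end index is begin + size, with a size of -1 replaced by the
// extent of the corresponding input dimension.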
3073 class ConvertSliceOpDynamic : public OpRewritePattern<TF::SliceOp> {
3074 public:
3075 using OpRewritePattern::OpRewritePattern;
3076
3077 LogicalResult matchAndRewrite(TF::SliceOp op,
3078 PatternRewriter &rewriter) const override {
3079 Location loc = op.getLoc();
3080 Value input = op.input();
3081 Value begin_indices = op.begin();
3082 Value sizes = op.size();
3083
3084 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
3085 auto begin_type = begin_indices.getType().dyn_cast<RankedTensorType>();
3086 auto size_type = sizes.getType().dyn_cast<RankedTensorType>();
3087
3088 if (!input_ty || !begin_type || !size_type ||
3089 !begin_type.hasStaticShape() || !size_type.hasStaticShape() ||
3090 begin_type.getRank() != 1 || size_type.getRank() != 1) {
3091 return failure();
3092 }
3093 // TODO(disc): Remove this static-size check once folding/canonicalization
3094 // is added.
3095 DenseIntElementsAttr size_attr;
3096 if (matchPattern(op.size(), m_Constant(&size_attr))) {
3097 return failure();
3098 }
3099
3100 int rank = begin_type.getDimSize(0);
3101 auto shape_scalar_type = begin_type.getElementType();
3102 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
3103 SmallVector<Value, 4> stride_values(rank, one);
3104 SmallVector<Value, 4> end_values;
3105 SmallVector<Value, 4> begin_values;
3106 end_values.reserve(rank);
3107 for (int i = 0; i < rank; ++i) {
3108 SmallVector<Value, 4> indices;
3109 indices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));
3110 auto begin_value =
3111 rewriter.create<tensor::ExtractOp>(loc, begin_indices, indices);
3112 auto size_value = rewriter.create<tensor::ExtractOp>(loc, sizes, indices);
3113 Value minus_one = rewriter.create<arith::IndexCastOp>(
3114 loc, shape_scalar_type,
3115 rewriter.create<arith::ConstantIndexOp>(loc, -1));
3116 auto is_minus_one = rewriter.create<arith::CmpIOp>(
3117 loc, arith::CmpIPredicate::eq, size_value, minus_one);
3118 Value end_value =
3119 rewriter.create<arith::AddIOp>(loc, begin_value, size_value);
3120 auto dim_value = rewriter.create<arith::IndexCastOp>(
3121 loc, shape_scalar_type,
3122 rewriter.create<tensor::DimOp>(loc, input, i));
3123 end_value = rewriter.create<mlir::arith::SelectOp>(loc, is_minus_one,
3124 dim_value, end_value);
3125 auto end_value_casted = rewriter.create<arith::IndexCastOp>(
3126 loc, rewriter.getIndexType(), end_value);
3127 end_values.push_back(end_value_casted);
3128
3129 auto begin_value_casted = rewriter.create<arith::IndexCastOp>(
3130 loc, rewriter.getIndexType(), begin_value);
3131 begin_values.push_back(begin_value_casted);
3132 }
3133 auto index_ty = rewriter.getIndexType();
3134 auto start_indices = rewriter.create<tensor::FromElementsOp>(
3135 loc,
3136 RankedTensorType::get({static_cast<int64_t>(begin_values.size())},
3137 index_ty),
3138 begin_values);
3139 auto end_indices = rewriter.create<tensor::FromElementsOp>(
3140 loc,
3141 RankedTensorType::get({static_cast<int64_t>(end_values.size())},
3142 index_ty),
3143 end_values);
3144 auto stride_indices = rewriter.create<tensor::FromElementsOp>(
3145 loc,
3146 RankedTensorType::get({static_cast<int64_t>(stride_values.size())},
3147 index_ty),
3148 stride_values);
3149
3150 auto d_slice = rewriter.create<mhlo::RealDynamicSliceOp>(
3151 loc, op.getOperation()->getResult(0).getType(), input, start_indices,
3152 end_indices, stride_indices);
3153 rewriter.replaceOp(op, d_slice.getOperation()->getResults());
3154 return success();
3155 }
3156 };
3157
3158 static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
3159 Value *out_lhs, Value *out_rhs,
3160 PatternRewriter *rewriter) {
3161 // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
3162 // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
3163 // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
3164 // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
3165 // To perform the matmul, we need to first broadcast lhs and rhs to a common
3166 // set of leading dimensions before doing the actual matmul.
3167 // That's what the code below does.
3168 // In particular, we populate out_lhs and out_rhs to have dimension structure:
3169 // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
3170 // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
3171 // To do this, we need to calculate those output shapes, which involves
3172 // slicing off the leading batch dims of each operand, broadcasting them,
3173 // then concatenating the broadcasted leading dims back to the row/col dims.
3174 // Finally, we create a TF::BroadcastTo op that does the actual broadcast.
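// For example (shapes are illustrative only): lhs : tensor<2x1x3x4xf32> and
// rhs : tensor<5x4x6xf32> are broadcast to out_lhs : tensor<2x5x3x4xf32> and
// out_rhs : tensor<2x5x4x6xf32>, so the eventual matmul result has shape
// [2, 5, 3, 6].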
3175
3176 // TODO(silvasean): Reduce duplication across reified shape calculations and
3177 // the static computation of output types needed to create ops.
3178 Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
3179 Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
3180 Value const_neg2 =
3181 rewriter->create<arith::ConstantOp>(loc, rewriter->getIndexAttr(-2));
3182 auto shape_type = shape::ShapeType::get(rewriter->getContext());
3183 auto lhs_splitted = rewriter->create<shape::SplitAtOp>(
3184 loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2);
3185 auto rhs_splitted = rewriter->create<shape::SplitAtOp>(
3186 loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2);
3187 auto lhs_type = lhs.getType().cast<RankedTensorType>();
3188 auto rhs_type = rhs.getType().cast<RankedTensorType>();
3189 // The last two dimensions are the matrix row/col dimensions. Don't broadcast
3190 // them.
3191 SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
3192 mlir::OpTrait::util::getBroadcastedShape(
3193 lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2),
3194 result_batch_shape_compile_time_extents);
3195 auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
3196 loc, shape_type, lhs_splitted.getHead(), rhs_splitted.getHead(),
3197 /*error=*/nullptr);
3198 // Lambda which handles the broadcasting of one side to the common
3199 // leading-batch dimensions.
3200 auto broadcast_one_side = [&](Value side, RankedTensorType type,
3201 Value tail_shape, Value *out_side) {
3202 ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
3203 auto result_shape = result_batch_shape_compile_time_extents;
3204 result_shape.append(matrix_dims.begin(), matrix_dims.end());
3205 auto result_type =
3206 RankedTensorType::get(result_shape, type.getElementType());
3207 auto shape =
3208 rewriter->create<shape::ConcatOp>(loc, shape_type, result_batch_shape, tail_shape);
3209 auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
3210 loc,
3211 RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
3212 rewriter->getIndexType()),
3213 shape);
3214 *out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
3215 shape_tensor);
3216 };
3217 broadcast_one_side(lhs, lhs_type, lhs_splitted.getTail(), out_lhs);
3218 broadcast_one_side(rhs, rhs_type, rhs_splitted.getTail(), out_rhs);
3219 }
3220
3221 class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
3222 public:
3223 // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved
3224 // to CHLO and it is missing legalization to MHLO. Once that is done, this
3225 // pattern's benefit can be changed back to one as well as the fallback
3226 // lowering pattern for the op can be removed.
3227 //
3228 // Set benefit of this pattern to zero to prefer the fallback pattern when
3229 // available and applicable. That pattern avoids broadcast on operands and is
3230 // therefore faster.
3231 //
3232 // Native legalization for BatchMatMulV3 needs to be added as well.
3233 explicit ConvertBatchMatMulV2Op(MLIRContext *context)
3234 : OpRewritePattern<TF::BatchMatMulV2Op>(context, /*benefit=*/0) {}
3235
3236 LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
3237 PatternRewriter &rewriter) const override {
3238 Value lhs = op.x();
3239 Value rhs = op.y();
3240 auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
3241 auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
3242 if (!lhs_type || !rhs_type) return failure();
3243 if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
3244 lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
3245 }
3246 if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
3247 rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
3248 }
3249
3250 // Broadcast both operands.
3251 BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
3252 &rewriter);
3253 lhs_type = lhs.getType().cast<RankedTensorType>();
3254 rhs_type = rhs.getType().cast<RankedTensorType>();
3255 assert(lhs_type.getRank() == rhs_type.getRank());
3256 int64_t rank = lhs_type.getRank();
3257 auto batch_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2));
3258 auto lhs_contracting_dimensions = llvm::to_vector<4>(
3259 llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}));
3260 auto rhs_contracting_dimensions = llvm::to_vector<4>(
3261 llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}));
3262 auto dimension_numbers = DotDimensionNumbersAttr::get(
3263 rewriter.getContext(),
3264 /*lhs_batching_dimensions=*/batch_dimensions,
3265 /*rhs_batching_dimensions=*/batch_dimensions,
3266 /*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
3267 /*rhs_contracting_dimensions=*/rhs_contracting_dimensions);
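// For example (illustrative): with rank 4 and adj_x == adj_y == false, the
// batching dimensions are [0, 1], the lhs contracting dimension is 3, and
// the rhs contracting dimension is 2.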
3268 // TODO(silvasean): Emit shape checks for contracting dimensions.
3269 // (The batch dimensions are checked by the broadcasting logic)
3270 rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
3271 dimension_numbers,
3272 /*precision_config=*/nullptr);
3273 return success();
3274 }
3275 };
3276
3277 // Converts the tf.Split op into a series of HLO slice ops when the tensor to be
3278 // split has fully static shape and the dimension to split is a constant.
3279 //
3280 // The main logic of this pattern is to calculate the index start and end
3281 // range for each slice, and this only happens on the dimension to be split;
3282 // for all other dimensions, each resultant slice's index range covers the
3283 // input tensor's full range. Strides for all resultant slices are one.
3284 //
3285 // For example, the following source IR:
3286 //
3287 // %dim = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
3288 // %0:3 = "tf.Split"(%dim, %input) : (tensor<i32>, tensor<4x6xf32>) ->
3289 // (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>)
3290 //
3291 // will be converted into:
3292 //
3293 // %0 = "mhlo.slice"(%input) {
3294 // limit_indices = dense<[4, 2]> : tensor<2xi64>,
3295 // start_indices = dense<0> : tensor<2xi64>,
3296 // strides = dense<1> : tensor<2xi64>} :
3297 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3298 // %1 = "mhlo.slice"(%input) {
3299 // limit_indices = dense<4> : tensor<2xi64>,
3300 // start_indices = dense<[0, 2]> : tensor<2xi64>,
3301 // strides = dense<1> : tensor<2xi64>} :
3302 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3303 // %2 = "mhlo.slice"(%input) {
3304 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
3305 // start_indices = dense<[0, 4]> : tensor<2xi64>,
3306 // strides = dense<1> : tensor<2xi64>} :
3307 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3308 // TODO(antiagainst): consider lowering into TF ops so the pattern can be more
3309 // applicable.
3310 class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
3311 public:
3312 using OpRewritePattern::OpRewritePattern;
3313
3314 LogicalResult matchAndRewrite(TF::SplitOp op,
3315 PatternRewriter &rewriter) const override {
3316 // We can only split along static dimensions.
3317 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3318 if (!input_type) return failure();
3319
3320 // We can only match when the split dimension is a constant scalar.
3321 DenseIntElementsAttr split_dim_attr;
3322 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3323 return failure();
3324
3325 // Get the dimension we are splitting at. Offset properly if it's negative.
3326 int64_t input_rank = input_type.getRank();
3327 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3328 if (dim_index < 0) dim_index += input_rank;
3329
3330 // Calculate the dimension size for each slice along the split dimension.
3331 int64_t input_dim_size = input_type.getDimSize(dim_index);
3332 // If we are splitting along the dynamic dimension then we cannot compute
3333 // the static dimension length.
3334 if (ShapedType::isDynamic(input_dim_size)) return failure();
3335
3336 int64_t num_splits = op.getNumResults();
3337 int64_t slice_size = input_dim_size / num_splits;
3338
3339 // Get each slice's type.
3340 auto slice_shape = llvm::to_vector<4>(input_type.getShape());
3341 slice_shape[dim_index] = slice_size;
3342 Type slice_type =
3343 RankedTensorType::get(slice_shape, input_type.getElementType());
3344
3345 // Parameters for constructing each slice.
3346 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3347 auto end_indices = llvm::to_vector<4>(input_type.getShape());
3348 SmallVector<int64_t, 4> strides(input_rank, 1);
3349
3350 // All HLO slice results used to replace the original tf.Split op.
3351 SmallVector<Value, 4> slices;
3352 slices.reserve(num_splits);
3353
3354 for (int i = 0; i < num_splits; ++i) {
3355 begin_indices[dim_index] = i * slice_size;
3356 end_indices[dim_index] = (i + 1) * slice_size;
3357 slices.push_back(
3358 rewriter.create<SliceOp>(op.getLoc(), slice_type, op.value(),
3359 GetI64ElementsAttr(begin_indices, &rewriter),
3360 GetI64ElementsAttr(end_indices, &rewriter),
3361 GetI64ElementsAttr(strides, &rewriter)));
3362 }
3363
3364 rewriter.replaceOp(op, slices);
3365 return success();
3366 }
3367 };
3368
3369 // Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops when
3370 // the dimension to split is a constant.
3371 // TODO(disc): Recover the static special case's performance with folding and
3372 // canonicalization, then delete ConvertSplitOp.
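// A rough sketch of the lowering per output (illustrative only): the split
// dimension's extent is read with tensor.dim, slice_size = extent / N, and
// the i-th output is
//
//   %out_i = mhlo.real_dynamic_slice %input, %begin_i, %end_i, %ones
//
// where %begin_i / %end_i hold i * slice_size and (i + 1) * slice_size in the
// split dimension and 0 / the dimension extent elsewhere.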
3373 class ConvertSplitOpDynamic : public OpRewritePattern<TF::SplitOp> {
3374 public:
3375 using OpRewritePattern::OpRewritePattern;
3376
3377 LogicalResult matchAndRewrite(TF::SplitOp op,
3378 PatternRewriter &rewriter) const override {
3379 Location loc = op.getLoc();
3380 Value input = op.value();
3381 auto input_type = input.getType().dyn_cast<RankedTensorType>();
3382 if (!input_type) return failure();
3383 // We can only match when the split dimension is a constant scalar.
3384 DenseIntElementsAttr split_dim_attr;
3385 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3386 return failure();
3387
3388 // Get the dimension we are splitting at. Offset properly if it's negative.
3389 int64_t input_rank = input_type.getRank();
3390 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3391 if (dim_index < 0) dim_index += input_rank;
3392
3393 // TODO(disc): Remove this static-shape check once folding/canonicalization
3394 // is added and ConvertSplitOp is deleted. This pattern only handles splitting
3395 // along a dynamic dimension; the static case is handled by the
3396 // ConvertSplitOp pattern above.
3397 int64_t c_input_dim_size = input_type.getDimSize(dim_index);
3398 if (!ShapedType::isDynamic(c_input_dim_size)) return failure();
3399
3400 Value input_dim_size =
3401 rewriter.create<tensor::DimOp>(loc, input, dim_index);
3402 // Calculate the dimension size for each slice along the split dimension.
3403 int num_splits = op.getNumResults();
3404 Value num_splits_value = rewriter.create<arith::ConstantOp>(
3405 loc, rewriter.getIndexAttr(num_splits));
3406 Value slice_size =
3407 rewriter.create<arith::DivSIOp>(loc, input_dim_size, num_splits_value);
3408
3409 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
3410 Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
3411
3412 SmallVector<Value, 4> begin_indices(input_rank, zero);
3413 SmallVector<Value, 4> end_indices;
3414 end_indices.reserve(input_rank);
3415 SmallVector<Value, 4> strides(input_rank, one);
3416 for (int i = 0; i < input_rank; ++i) {
3417 end_indices.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
3418 }
3419
3420 // All HLO d_slice results used to replace the original tf.Split op.
3421 SmallVector<Value, 4> slices;
3422 slices.reserve(num_splits);
3423
3424 for (int i = 0; i < num_splits; ++i) {
3425 begin_indices[dim_index] = rewriter.create<arith::MulIOp>(
3426 loc, slice_size, rewriter.create<arith::ConstantIndexOp>(loc, i));
3427 end_indices[dim_index] = rewriter.create<arith::MulIOp>(
3428 loc, slice_size, rewriter.create<arith::ConstantIndexOp>(loc, i + 1));
3429
3430 Type index_ty = rewriter.getIndexType();
3431 auto begin_value = rewriter.create<tensor::FromElementsOp>(
3432 loc,
3433 RankedTensorType::get({static_cast<int64_t>(begin_indices.size())},
3434 index_ty),
3435 begin_indices);
3436 auto end_value = rewriter.create<tensor::FromElementsOp>(
3437 loc,
3438 RankedTensorType::get({static_cast<int64_t>(end_indices.size())},
3439 index_ty),
3440 end_indices);
3441 auto stride_value = rewriter.create<tensor::FromElementsOp>(
3442 loc,
3443 RankedTensorType::get({static_cast<int64_t>(strides.size())},
3444 index_ty),
3445 strides);
3446 slices.push_back(rewriter.create<RealDynamicSliceOp>(
3447 loc, op.getOperation()->getResult(i).getType(), input, begin_value,
3448 end_value, stride_value));
3449 }
3450
3451 rewriter.replaceOp(op, slices);
3452 return success();
3453 }
3454 };
3455
3456 // Converts the tf.SplitV op into a series of HLO slice ops when the tensor to
3457 // be split has fully static shape and the dimension to split and split sizes
3458 // are constants.
3459 //
3460 // This is similar to the conversion for the tf.Split op, except that the size
3461 // of each chunk along the dimension to split is explicitly given as an op
3462 // operand and the sizes are not necessarily equal.
3463 //
3464 // For example, given the following IR:
3465 //
3466 // %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>}
3467 // %split_dim = "tf.Const"() {value = dense<1> : tensor<i32>}
3468 // %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) :
3469 // (tensor<4x6xf32>, tensor<3xi32>, tensor<i32>) ->
3470 // (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
3471 //
3472 // We will generate the following slices:
3473 // %0 = "mhlo.slice"(%input) {
3474 // limit_indices = dense<[4, 1]> : tensor<2xi64>,
3475 // start_indices = dense<0> : tensor<2xi64>,
3476 // strides = dense<1> : tensor<2xi64>} :
3477 // (tensor<4x6xf32>) -> tensor<4x1xf32>
3478 // %1 = "mhlo.slice"(%input) {
3479 // limit_indices = dense<[4, 3]> : tensor<2xi64>,
3480 // start_indices = dense<[0, 1]> : tensor<2xi64>,
3481 // strides = dense<1> : tensor<2xi64>} :
3482 // (tensor<4x6xf32>) -> tensor<4x2xf32>
3483 // %2 = "mhlo.slice"(%input) {
3484 // limit_indices = dense<[4, 6]> : tensor<2xi64>,
3485 // start_indices = dense<[0, 3]> : tensor<2xi64>,
3486 // strides = dense<1> : tensor<2xi64>} :
3487 // (tensor<4x6xf32>) -> tensor<4x3xf32>
3488 class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
3489 public:
3490 using OpRewritePattern::OpRewritePattern;
3491
3492 LogicalResult matchAndRewrite(TF::SplitVOp op,
3493 PatternRewriter &rewriter) const override {
3494 // We can only split along static dimensions.
3495 // TODO(b/145731001): enhance to support dynamic-shaped inputs.
3496 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
3497 if (!input_type) return failure();
3498
3499 // We can only match when the split dimension is a constant scalar.
3500 DenseIntElementsAttr split_dim_attr;
3501 if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr)))
3502 return failure();
3503
3504 // We can only match when the split sizes is a constant int vector.
3505 DenseIntElementsAttr split_sizes_attr;
3506 if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr)))
3507 return failure();
3508
3509 // Get each chunk's size along the dimension to split. The list may contain
3510 // a dynamic size, which we update below with the calculated concrete size.
3511 SmallVector<int64_t, 4> split_sizes;
3512 int64_t total_dim_size = 0; // Total dimension size assigned to splits
3513 llvm::Optional<int> dynamic_dim_index;
3514 split_sizes.reserve(
3515 split_sizes_attr.getType().cast<ShapedType>().getNumElements());
3516 for (auto dim : llvm::enumerate(split_sizes_attr)) {
3517 int64_t dim_val = dim.value().getSExtValue();
3518 split_sizes.push_back(dim_val);
3519 if (dim_val == ShapedType::kDynamicSize) {
3520 // We cannot have more than one dynamic dimension.
3521 assert(!dynamic_dim_index && "invalid split sizes");
3522 dynamic_dim_index = dim.index();
3523 } else {
3524 total_dim_size += dim_val;
3525 }
3526 }
3527
3528 // Get the dimension we are splitting at. Offset properly if it's negative.
3529 int64_t input_rank = input_type.getRank();
3530 int64_t dim_index = (*split_dim_attr.begin()).getSExtValue();
3531 if (dim_index < 0) dim_index += input_rank;
3532
3533 int64_t input_dim_size = input_type.getDimSize(dim_index);
3534 if (ShapedType::isDynamic(input_dim_size)) return failure();
3535
3536 assert(((dynamic_dim_index && total_dim_size <= input_dim_size) ||
3537 (!dynamic_dim_index && total_dim_size == input_dim_size)) &&
3538 "invalid split sizes");
3539
3540 // Update the dynamic dimension with calculated concrete size.
3541 if (dynamic_dim_index)
3542 split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size;
3543
3544 // Parameters for constructing each slice.
3545 SmallVector<int64_t, 4> begin_indices(input_rank, 0);
3546 auto end_indices = llvm::to_vector<4>(input_type.getShape());
3547 SmallVector<int64_t, 4> strides(input_rank, 1);
3548
3549 // All HLO slice results used to replace the original tf.Split op.
3550 SmallVector<Value, 4> slices;
3551 slices.reserve(op.getNumResults());
3552
3553 for (int i = 0, end = op.getNumResults(); i < end; ++i) {
3554 end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
3555 slices.push_back(rewriter.create<mhlo::SliceOp>(
3556 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
3557 GetI64ElementsAttr(end_indices, &rewriter),
3558 GetI64ElementsAttr(strides, &rewriter)));
3559 // Prepare the begin index for the next slice.
3560 begin_indices[dim_index] = end_indices[dim_index];
3561 }
3562
3563 rewriter.replaceOp(op, slices);
3564 return success();
3565 }
3566 };
3567
3568 // Converts StridedSlice op to HLO Slice op along with Reverse op to handle
3569 // negative strides and Reshape op to update the output shape. Indices and
3570 // strides operands are converted to attributes with non-negative indexing.
3571 //
3572 // If the begin input is not a compile time constant, the begin input needs to
3573 // be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this
3574 // case, strides must have a known value of 1 (otherwise we have insufficient
3575 // information to conform to XLA's op semantics).
3576 //
3577 // For example with an op like following,
3578 // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1}
3579 // : tensor<AxBxf32> -> tensor<Pxf32>
3580 //
3581 // If the %begin input is constant, output would be:
3582 // %reversed = "mhlo.Reverse" (%input) {dimensions = ...}
3583 // %sliced = "mhlo.Slice" (%input)
3584 // {start_indices = ..., limit_indices = ..., strides = ...}
3585 // %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
3586 //
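// If the %begin input is not constant (and all strides are 1), a rough sketch
// of the output is (illustrative only):
//   %begin_d = "mhlo.slice"(%begin) ... // one scalar index per sliced dim
//   %sliced = "mhlo.dynamic_slice"(%input, %begin_0, %begin_1, ...)
//       {slice_sizes = ...}
//   %output = "mhlo.reshape"(%sliced) : tensor<1xPxf32> -> tensor<Pxf32>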
3587 class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
3588 public:
3589 using OpRewritePattern::OpRewritePattern;
3590
3591 LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op,
3592 ArrayRef<int64_t> begin_indices,
3593 ArrayRef<int64_t> end_indices,
3594 ArrayRef<int64_t> strides,
3595 RankedTensorType input_ty,
3596 PatternRewriter &rewriter) const {
3597 SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides,
3598 dims_to_reverse;
3599 int64_t input_rank = input_ty.getRank();
3600 ArrayRef<int64_t> input_shape = input_ty.getShape();
3601 hlo_begin_indices.reserve(input_rank);
3602 hlo_end_indices.reserve(input_rank);
3603 hlo_strides.reserve(input_rank);
3604
3605 int64_t indices_elements = begin_indices.size();
3606 if (input_rank < indices_elements) return failure();
3607
3608 // Convert from TensorFlow negative or out of range indices and strides
3609 // values to legal HLO Slice attributes.
3610 for (int i = 0, e = indices_elements; i != e; i++) {
3611 int64_t begin = begin_indices[i];
3612 int64_t end = end_indices[i];
3613 int64_t stride = strides[i];
3614
3615 if (stride < 0) {
3616 // Negative stride means that the output values are computed starting
3617 // from end until begin. Mark the dimension for reversal before slice
3618 // and compute indices for the reversed input.
3619 dims_to_reverse.push_back(i);
3620 begin = (input_shape[i] - 1) - begin;
3621 end = (input_shape[i] - 1) - end;
3622 stride = -stride;
3623 }
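// Worked example (illustrative values): with input_shape[i] = 10, begin = 8,
// end = 2 and stride = -2, dimension i is marked for reversal and the
// indices become begin = 1, end = 7, stride = 2. After the clamping below,
// this selects reversed positions 1, 3 and 5, i.e. original positions 8, 6
// and 4, matching TF's negative-stride semantics.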
3624
3625 // Unlike TensorFlow, HLO requires begin and end values to be within
3626 // range.
3627 begin = std::max(int64_t(0), begin);
3628 end = std::max(begin, end);
3629 end = std::min(end, input_shape[i]);
3630
3631 hlo_begin_indices.push_back(begin);
3632 hlo_end_indices.push_back(end);
3633 hlo_strides.push_back(stride);
3634 }
3635
3636 Location loc = op.getLoc();
3637 Value input = op.input();
3638 if (!dims_to_reverse.empty())
3639 input = rewriter.create<ReverseOp>(
3640 loc, input_ty, op.input(),
3641 GetI64ElementsAttr(dims_to_reverse, &rewriter));
3642 auto sliced = rewriter.create<SliceOp>(
3643 loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter),
3644 GetI64ElementsAttr(hlo_end_indices, &rewriter),
3645 GetI64ElementsAttr(hlo_strides, &rewriter));
3646
3647 // Reshape slice result so that the shape is updated depending on
3648 // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3649 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3650 return success();
3651 }
3652
3653 LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op,
3654 RankedTensorType input_ty,
3655 RankedTensorType result_ty,
3656 PatternRewriter &rewriter) const {
3657 // If begin and end values are dynamic, we can only support this lowering
3658 // if strides are a known value of 1.
3659 DenseIntElementsAttr sparse_strides_attr;
3660 if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) {
3661 return rewriter.notifyMatchFailure(
3662 op,
3663 "requires that strides are known when begin/end values are dynamic");
3664 }
3665 SmallVector<int64_t, 4> strides;
3666 int64_t stride_value;
3667 for (const APInt &stride : sparse_strides_attr) {
3668 if ((stride_value = stride.getSExtValue()) != 1) {
3669 return rewriter.notifyMatchFailure(op,
3670 "requires that strides are all 1 "
3671 "when begin/end values are dynamic");
3672 }
3673 strides.push_back(stride_value);
3674 }
3675
3676 ArrayRef<int64_t> input_shape = input_ty.getShape();
3677 int last_dim = std::max(static_cast<int>(input_shape.size()) - 1, 0);
3678
3679 // When begin/end values are dynamic, we can only support shrinking a major
3680 // axis. For instance, if there are 4 dims, we can support a
3681 // shrink_axis_mask of 0001 (1), 0011 (3), 0111 (7), or 1111 (15), but no
3682 // other.
3683 bool shrink_axis_mask_ok = llvm::isMask_64(op.shrink_axis_mask());
3684 if (!shrink_axis_mask_ok)
3685 return rewriter.notifyMatchFailure(
3686 op,
3687 "requires that shrink_axis_mask, if set, refer to a major axis "
3688 "dimension (when begin/end values are dynamic)");
3689
3690 // When begin/end values are dynamic, the ellipsis mask, if set, must refer
3691 // to the last dimension.
3692 int ellipsis_mask = op.ellipsis_mask();
3693 if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim)))
3694 return rewriter.notifyMatchFailure(
3695 op,
3696 "requires that ellipsis_mask, if set, refer to the last dimension of "
3697 "input (when begin/end values are dynamic)");
3698
3699 uint64_t new_axis_mask = op.new_axis_mask();
3700 if (new_axis_mask)
3701 return rewriter.notifyMatchFailure(
3702 op,
3703 "requires that new_axis_mask is either set to 0 or not set when "
3704 "begin/end values are dynamic");
3705
3706     // When the begin and end values are dynamic, we only support cases where
3707     // the number of output elements is equal to the number of input elements
3708     // that are sliced. Each dimension is either sliced fully or sliced with a
3709     // size of one.
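    // For example (hypothetical shapes): slicing a tensor<2x3x4xf32> with a
    // dynamic begin/end over the first dimension and shrink_axis_mask = 1
    // yields a tensor<3x4xf32>; the dimensions that are not sliced contribute
    // 3 * 4 = 12 elements, matching the 12 output elements, so the lowering
    // below applies.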
3710 int output_elements = result_ty.getNumElements();
3711 int input_elements_sliced = 1;
3712
3713 // Begin must be a ranked, 1-dimensional tensor: This is checked by the
3714 // verifier.
3715 int64_t slicing_dim_size =
3716 op.begin().getType().cast<RankedTensorType>().getDimSize(0);
3717 uint64_t begin_mask = op.begin_mask();
3718 uint64_t end_mask = op.end_mask();
3719 const int input_rank = input_shape.size();
3720 for (int d = 0; d < input_rank; ++d) {
3721 // Each dimension is either sliced fully or has size of one.
3722 if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) ||
3723 (d >= slicing_dim_size)) {
3724 input_elements_sliced *= input_shape[d];
3725 }
3726 }
3727 if (input_elements_sliced != output_elements) {
3728 return rewriter.notifyMatchFailure(
3729 op,
3730 "requires the number of output elements to be equal to the number of "
3731 "input elements sliced (when begin/end values are dynamic)");
3732 }
3733
3734 SmallVector<Value, 4> slice_begin_indices;
3735 // For the dimensions that are to be sliced, all have slice sizes of 1.
3736 SmallVector<int64_t, 4> slice_sizes;
3737 auto begin_element_ty =
3738 op.begin().getType().cast<ShapedType>().getElementType();
3739 // Scalar tensor type.
3740 TensorType type = RankedTensorType::get(/*shape=*/{}, begin_element_ty);
3741 Location loc = op.getLoc();
3742 auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter);
3743 for (int d = 0; d < input_rank; ++d) {
3744 if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) ||
3745 (d >= slicing_dim_size)) {
3746 slice_begin_indices.push_back(zero);
3747 slice_sizes.push_back(input_shape[d]);
3748 continue;
3749 }
3750
3751 auto index = rewriter.create<SliceOp>(
3752 loc, op.begin(), GetI64ElementsAttr({d}, &rewriter),
3753 GetI64ElementsAttr({d + 1}, &rewriter),
3754 GetI64ElementsAttr({1}, &rewriter));
3755 // Convert index to scalar.
3756 auto reshaped_index = rewriter.create<ReshapeOp>(loc, type, index);
3757 // If the index is negative, wrap it around with dimension size.
3758 auto index_negative =
3759 rewriter.create<TF::LessOp>(loc, reshaped_index, zero);
3760 auto input_val = GetScalarConstOfType(begin_element_ty, loc,
3761 input_shape[d], &rewriter);
3762 auto wrapped_index =
3763 rewriter.create<TF::AddV2Op>(loc, input_val, reshaped_index);
3764 auto final_index = rewriter.create<SelectOp>(
3765 loc, type, index_negative, wrapped_index, reshaped_index);
3766 slice_begin_indices.push_back(final_index);
3767 slice_sizes.push_back(1);
3768 }
3769
3770 auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter);
3771 auto sliced_type =
3772 RankedTensorType::get(slice_sizes, op.getType().getElementType());
3773 // This must be an xla DynamicSlice op due to the inputs that aren't
3774 // constant.
3775 auto sliced = rewriter.create<DynamicSliceOp>(
3776 loc, sliced_type, op.input(), slice_begin_indices, slice_sizes_attr);
3777
3778 // Reshape slice result so that the shape is updated depending on
3779 // 'new_axis_mask' or 'shrink_axis_mask' attributes.
3780 rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), sliced);
3781 return success();
3782 }
3783
3784   LogicalResult matchAndRewrite(TF::StridedSliceOp op,
3785 PatternRewriter &rewriter) const override {
3786 // Input shape needs to be static to convert negative indices in TensorFlow
3787 // to absolute indices required by HLO.
3788 //
3789 // TODO(hinsu): Relax this constraint for ops without negative indices and
3790 // strides.
3791 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
3792 if (!input_ty || !input_ty.hasStaticShape()) return failure();
3793
3794 // Output shape needs to be static to apply 'new_axis_mask' or
3795 // 'shrink_axis_mask' by reshaping tensor after slice.
3796 //
3797 // TODO(hinsu): Relax this constraint for ops without the above masks.
3798 auto result_ty = op.getType().dyn_cast<RankedTensorType>();
3799 if (!result_ty || !result_ty.hasStaticShape()) return failure();
3800
3801 DenseIntElementsAttr sparse_begin_attr, sparse_end_attr;
3802 if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) ||
3803 !matchPattern(op.end(), m_Constant(&sparse_end_attr))) {
3804 return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter);
3805 }
3806
3807 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3808 if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) {
3809 return failure();
3810 }
3811 return rewriteWithConstantBegin(op, begin_indices, end_indices, strides,
3812 input_ty, rewriter);
3813 }
3814 };
3815
3816 // Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops.
3817 //
3818 // tf.StridedSlice takes a slice of the input tensor. tf.StridedSliceGrad does
3819 // the reverse: it propagates the gradient for the sliced tensor to the original
3820 // input tensor by doing padding with zeros. The main logic is calculating the
3821 // indices and strides for padding.
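// As an illustration (hypothetical 1-D shapes): for a forward slice
// input[2:8:2] of a tensor<10xf32>, the gradient dy has shape tensor<3xf32>.
// The lowering below pads dy back to the input shape with padding_low = 2,
// interior padding = 1 and padding_high = 3, i.e. 2 + 3 + 2 * 1 + 3 = 10
// elements, placing each gradient value at the index it was read from.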
3822 class ConvertStridedSliceGradOp
3823 : public OpRewritePattern<TF::StridedSliceGradOp> {
3824 public:
3825 using OpRewritePattern::OpRewritePattern;
3826
3827   LogicalResult matchAndRewrite(TF::StridedSliceGradOp op,
3828 PatternRewriter &rewriter) const override {
3829 // We need constant input shape to perform padding calculations later.
3830 DenseIntElementsAttr input_shape_attr;
3831 if (!matchPattern(op.shape(), m_Constant(&input_shape_attr)))
3832 return failure();
3833
3834 // We also need constant begin/end indices and strides to perform padding
3835 // calculations.
3836 // Bounded shape after performing strided slice
3837 SmallVector<int64_t, 4> shape;
3838 // Bounded begin, end, and strides for strided slice
3839 SmallVector<int64_t, 4> begin_indices, end_indices, strides;
3840 if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices,
3841 &strides))
3842 return failure();
3843
3844 Value grad = op.dy();
3845 Type element_type = grad.getType().cast<ShapedType>().getElementType();
3846
3847 // Perform reshape to undo any new/shrink axes done by strided slice.
3848 grad = rewriter.create<mhlo::ReshapeOp>(
3849 op.getLoc(), RankedTensorType::get(shape, element_type), grad);
3850
3851 SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
3852 SmallVector<int64_t, 4> dims_to_reverse;
3853 padding_low.reserve(shape.size());
3854 padding_high.reserve(shape.size());
3855 padding_interm.reserve(shape.size());
3856
3857 // Prepare padding parameters for each dimension.
3858 for (int i = 0, e = shape.size(); i < e; ++i) {
3859 int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue();
3860 if (strides[i] > 0) {
3861 padding_low.push_back(begin_indices[i]);
3862 padding_interm.push_back(strides[i] - 1);
3863
3864 // Pad the upper dimension up to the expected input shape. It's not
3865 // sufficient simply to use end_indices[i] to compute the padding in
3866 // cases where the stride does not divide evenly into the interval
3867 // between begin_indices[i] and end_indices[i].
3868 int64_t size =
3869 padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3870 padding_high.push_back(input_dim - size);
3871 } else {
3872 dims_to_reverse.push_back(i);
3873 padding_high.push_back(input_dim - begin_indices[i] - 1);
3874 padding_interm.push_back(-strides[i] - 1);
3875
3876 // Pad the lower dimension up to the expected input shape.
3877 int64_t size =
3878 padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i];
3879 padding_low.push_back(input_dim - size);
3880 }
3881 }
3882
3883 if (!dims_to_reverse.empty()) {
3884 grad = rewriter.create<mhlo::ReverseOp>(
3885 op.getLoc(), grad.getType(), grad,
3886 GetI64ElementsAttr(dims_to_reverse, &rewriter));
3887 }
3888
3889 auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
3890 rewriter.replaceOpWithNewOp<mhlo::PadOp>(
3891 op, op.getType(), grad, zero,
3892 GetI64ElementsAttr(padding_low, &rewriter),
3893 GetI64ElementsAttr(padding_high, &rewriter),
3894 GetI64ElementsAttr(padding_interm, &rewriter));
3895 return success();
3896 }
3897 };
3898
3899 /// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and
3900 /// offset applied to generate the range values. The output tensor needs to
3901 /// have a static shape.
3902 ///
3903 /// For example an op like the following:
3904 /// %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"}
3905 /// : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
3906 ///
3907 /// Output would be:
3908 /// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
3909 /// %scaled = "mhlo.multiply"(%iota, %delta)
3910 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3911 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3912 /// %result = "mhlo.add"(%scaled, %offset)
3913 /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
3914 /// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
3915 ///
3916 /// The implementation is defined in C++ because the iota op has no type
3917 /// inference interface.
3917 class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
3918 using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3919
3920   LogicalResult matchAndRewrite(TF::RangeOp op,
3921 PatternRewriter &rewriter) const override {
3922 auto result = op.getResult();
3923 auto result_type = result.getType();
3924 if (!result_type.cast<ShapedType>().hasStaticShape()) {
3925 return failure();
3926 }
3927
3928 auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
3929 rewriter.getI64IntegerAttr(0));
3930 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
3931 op.getLoc(), result_type, iota, op.delta(),
3932 hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
3933 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
3934 op, result_type, scaled, op.start(),
3935 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
3936 return success();
3937 }
3938 };
3939
3940 // Converts RangeOp for cases where the length is a dynamic value. The shape of
3941 // the resulting tensor is computed, then the start and delta are used with the
3942 // dynamic_iota value to compute the final range values.
3943 //
3944 // For example, the resulting range op value:
3945 // %range = "tf.range"(%start, %limit, %delta)
3946 //
3947 // Is converted to the following.
3948 //   %start + %delta * iota(ceil(abs((%limit - %start) / %delta)))
3949 //
3950 // Implementation is defined in C++ due to the complicated type behavior.
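// For example (hypothetical values): start = 0, limit = 7, delta = 2 gives
// size = ceil(abs((7 - 0) / 2)) = 4, so the lowered computation produces
// 0 + 2 * iota(4) = [0, 2, 4, 6].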
3951 class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
3952 using OpRewritePattern<TF::RangeOp>::OpRewritePattern;
3953
3954   LogicalResult matchAndRewrite(TF::RangeOp op,
3955 PatternRewriter &rewriter) const override {
3956 auto result = op.getResult();
3957 auto result_type = result.getType().cast<ShapedType>();
3958 if (result_type.hasStaticShape()) {
3959 return failure();
3960 }
3961
3962 Value start = op.start();
3963 Value delta = op.delta();
3964 Value limit = op.limit();
3965
3966 // To compute the length we need to use floating point calculations so that
3967 // ceil can be computed for the number of steps.
3968 auto compute_element_type =
3969 getElementTypeOrSelf(start.getType()).isa<FloatType>()
3970 ? getElementTypeOrSelf(start.getType())
3971 : rewriter.getF64Type();
3972 auto compute_type = RankedTensorType::get(
3973 limit.getType().cast<ShapedType>().getShape(), compute_element_type);
3974
3975 // Compute the length of the sequence we are going to need. This includes
3976 // some conversion to float for the operations.
3977 //
3978 // %size = ceil(abs((%limit - %start) / %delta))
3979 auto range = rewriter.create<mhlo::SubtractOp>(op.getLoc(), limit, start);
3980 auto abs = rewriter.create<mhlo::AbsOp>(op.getLoc(), range);
3981
3982 // Delta is not necessarily the same type as start and limit.
3983 auto abs_cast =
3984 rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, abs);
3985 auto delta_cast =
3986 rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, delta);
3987
3988 // Compute the total number of integer steps and convert to the HLO
3989 // dimension tensor.
3990 auto normalized =
3991 rewriter.create<mhlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
3992 auto ceil = rewriter.create<mhlo::CeilOp>(op.getLoc(), normalized);
3993 auto steps = rewriter.create<mhlo::ConvertOp>(
3994 op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
3995 auto reshape = rewriter.create<mhlo::ReshapeOp>(
3996 op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
3997
3998 // Using the resulting length compute the correct range value:
3999 //
4000 // %range = %start + %delta * iota(%size)
4001 auto out_scalar_type =
4002 RankedTensorType::get({}, getElementTypeOrSelf(result_type));
4003 auto start_out_cast =
4004 rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, start);
4005 auto delta_out_cast =
4006 rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, delta);
4007
4008 auto iota = rewriter.create<DynamicIotaOp>(
4009 op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
4010 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4011 op.getLoc(), result_type, iota, delta_out_cast,
4012 hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
4013 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4014 op, result_type, scaled, start_out_cast,
4015 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
4016 return success();
4017 }
4018 };
4019
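// Normalizes the axis values in `attr` into the [0, rank) range for the rank
// of `val` and returns them as a 1-D i64 tensor attribute. For example
// (hypothetical values): for a rank-4 `val` and axes [-1, 1], the result is
// [3, 1].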
4020 ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
4021 auto int_attr = attr.cast<DenseIntElementsAttr>();
4022 auto type = val.getType().cast<ShapedType>();
4023
4024 SmallVector<int64_t, 6> axis;
4025 axis.reserve(int_attr.getNumElements());
4026
4027 int64_t rank = type.getRank();
4028 for (auto val : int_attr.getValues<APInt>()) {
4029 axis.push_back((val.getSExtValue() + rank) % rank);
4030 }
4031
4032 return builder->getI64TensorAttr(axis);
4033 }
4034
4035 /// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling
4036 /// and offset applied to generate the linspace values. The output tensor needs
4037 /// to have a static shape. The implementation is defined in C++ because there
4038 /// is no type inference for the iota op.
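/// For example (hypothetical values): start = 0, stop = 10, num = 5 gives
/// step = (10 - 0) / (5 - 1) = 2.5, so the lowered iota-based computation
/// produces [0.0, 2.5, 5.0, 7.5, 10.0].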
4039 class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
4040 using OpRewritePattern<TF::LinSpaceOp>::OpRewritePattern;
4041
4042   LogicalResult matchAndRewrite(TF::LinSpaceOp op,
4043 PatternRewriter &rewriter) const override {
4044 auto result = op.getResult();
4045 auto result_type = result.getType().dyn_cast<ShapedType>();
4046 if (!result_type || !result_type.hasStaticShape()) {
4047 return failure();
4048 }
4049
4050 DenseIntElementsAttr num_attr;
4051 if (!matchPattern(op.num(), m_Constant(&num_attr))) {
4052 return rewriter.notifyMatchFailure(op, "Num must be a constant scalar");
4053 }
4054
4055 if (num_attr.begin() == num_attr.end()) {
4056 return rewriter.notifyMatchFailure(op, "Num must not be empty");
4057 }
4058 int64_t num = (*num_attr.begin()).getSExtValue();
4059
4060 // Calculate the scaling that needs to be applied to the iota.
4061 auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
4062 op.getLoc(), op.start().getType(), op.stop(), op.start(),
4063 hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
4064 Value step_denominator = rewriter.create<ConvertOp>(
4065 op.getLoc(), op.num(), result_type.getElementType());
4066 if (num > 1) {
4067 Value one = GetScalarConstOfType(result_type.getElementType(),
4068 op.getLoc(), 1, &rewriter);
4069 step_denominator = rewriter.create<chlo::BroadcastSubOp>(
4070 op.getLoc(), step_denominator.getType(), step_denominator, one,
4071 hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
4072 }
4073 auto step = rewriter.create<chlo::BroadcastDivOp>(
4074 op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
4075 hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator,
4076 step_denominator));
4077
4078 // Scale the iota and add the offset.
4079 auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
4080 rewriter.getI64IntegerAttr(0));
4081 auto scaled = rewriter.create<chlo::BroadcastMulOp>(
4082 op.getLoc(), result_type, iota, step,
4083 hlo::getBroadcastDimensionsAttr(&rewriter, iota, step));
4084 rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
4085 op, result_type, scaled, op.start(),
4086 hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
4087 return success();
4088 }
4089 };
4090
4091 /// Converts a generic OpTy tensorflow op to a mhlo.reduce op over
4092 /// ReductionOp.
4093 /// `is_accumulation` controls whether it uses higher precision for the actual
4094 /// reduction. This is set to false for ops like max where there are no
4095 /// precision concerns.
4096 //
4097 // The Derived class should have a static method to return the initial value to
4098 // use for reduction:
4099 // static Value GetInitialValue(Type reduce_element_type, Location loc,
4100 // PatternRewriter *rewriter);
4101 // The reduce_element_type is guaranteed to be a float, int, or complex type
4102 // suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType.
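// As a sketch of the generated IR (hypothetical shapes): a tf.Sum over axis 1
// of a tensor<4x8xbf16> with keep_dims = true lowers to a convert into the
// accumulation type, an mhlo.reduce over dimension 1 producing tensor<4xf32>,
// a convert back to bf16, and a tf.ExpandDims that restores the reduced
// dimension, yielding tensor<4x1xbf16>.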
4103 template <typename Derived, typename OpTy, typename ReductionOp,
4104 bool is_accumulation = true>
4105 class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
4106 using OpRewritePattern<OpTy>::OpRewritePattern;
4107
4108   LogicalResult matchAndRewrite(OpTy op,
4109 PatternRewriter &rewriter) const override {
4110 // TODO(b/141785544): Update this to not require ranked shapes.
4111 // Input shape needs to be ranked to convert negative indices in TensorFlow
4112 // to absolute indices required by HLO.
4113 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
4114 if (!input_ty) return failure();
4115 ArrayRef<int64_t> input_shape = input_ty.getShape();
4116
4117 DenseIntElementsAttr dimensions;
4118 if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)))
4119 return failure();
4120
4121 // Build the final shape from input_shape and dimensions using a bitmap
4122 // to mark the reduced dimensions.
4123 SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false);
4124 SmallVector<int64_t, 4> xla_dimensions;
4125 for (const APInt &index_raw : dimensions.getValues<APInt>()) {
4126 int64_t index = index_raw.getSExtValue();
4127 int64_t rank = input_shape.size();
4128 if ((index < -rank || index >= rank)) return failure();
4129 index = (index + rank) % rank;
4130 reduced_dimensions_bitmap[index] = true;
4131 xla_dimensions.push_back(index);
4132 }
4133
4134 Location loc = op.getLoc();
4135 Type element_type = input_ty.getElementType();
4136
4137 // Only float, int, and complex types are currently supported.
4138 if (!element_type.isa<FloatType>() && !element_type.isa<IntegerType>() &&
4139 !element_type.isa<ComplexType>()) {
4140 return rewriter.notifyMatchFailure(
4141 op, "element type must be float, int, or complex type");
4142 }
4143
4144 // Convert to an accumulation type to not lose precision when doing
4145 // repeated arithmetic operations.
4146 Type reduce_element_type =
4147 is_accumulation ? GetAccumulationType(element_type) : element_type;
4148 auto casted_input =
4149 rewriter.create<ConvertOp>(loc, op.input(), reduce_element_type);
4150
4151 // Each reduction op can have a different initial value.
4152 Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter);
4153
4154 auto reduction = rewriter.create<ReduceOp>(
4155 loc, casted_input.getResult(), init,
4156 GetI64ElementsAttr(xla_dimensions, &rewriter));
4157 BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(),
4158 &rewriter);
4159 Value result = reduction.getResult(0);
4160
4161 // The mean op needs to divide by the product of the reduced dimensions.
4162 if (std::is_same<OpTy, TF::MeanOp>::value) {
4163 Value in_shape = rewriter.create<shape::ShapeOfOp>(loc, op.input());
4164 Value divisor_count = rewriter.create<arith::ConstantIndexOp>(loc, 1);
4165 for (size_t i = 0; i < input_shape.size(); ++i) {
4166 if (reduced_dimensions_bitmap[i]) {
4167 Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
4168 auto dim = rewriter.create<tensor::ExtractOp>(loc, in_shape, index);
4169 divisor_count =
4170 rewriter.create<arith::MulIOp>(loc, divisor_count, dim);
4171 }
4172 }
4173 // HLO ops are only defined on tensors, so we cast the divisor from
4174 // index -> i64 -> tensor<1xi64> -> tensor<i64> -> tensor<reduction type>
4175 Value divisor_casted = rewriter.create<arith::IndexCastOp>(
4176 loc, rewriter.getI64Type(), divisor_count);
4177 Value divisor_tensor = rewriter.create<tensor::FromElementsOp>(
4178 loc, RankedTensorType::get({}, rewriter.getI64Type()),
4179 divisor_casted);
4180 Value divisor = rewriter.create<ConvertOp>(
4181 loc, RankedTensorType::get({}, reduce_element_type), divisor_tensor);
4182 auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
4183 result = rewriter.create<chlo::BroadcastDivOp>(loc, result, divisor,
4184 broadcast_dims);
4185 }
4186
4187 result = rewriter.create<ConvertOp>(loc, result, element_type);
4188
4189 // Need to reshape back after the reduction if we're keeping the reduced
4190 // dimensions. Note that we do this through successive (nominally 1)
4191 // applications of the TF ExpandDims op vs a more labor intensive
4192 // reshape. Various code generation techniques benefit from the knowledge
4193 // that this is a restricted form of shape manipulation that is just adding
4194 // unit dims.
4195 if (op.keep_dims()) {
4196 for (auto &dim_is_reduced : llvm::enumerate(reduced_dimensions_bitmap)) {
4197 if (dim_is_reduced.value()) {
4198 auto index_attr = GetI32ElementsAttr(
4199 {static_cast<int>(dim_is_reduced.index())}, &rewriter);
4200 Value index = rewriter.create<arith::ConstantOp>(loc, index_attr);
4201 result = rewriter.create<TF::ExpandDimsOp>(loc, result, index);
4202 }
4203 }
4204 }
4205 rewriter.replaceOp(op, {result});
4206
4207 return success();
4208 }
4209 };
4210
4211 // Converts Mean op to HLO Reduce op.
4212 //
4213 // %init = arith.constant dense<...> : tensor<T>
4214 // %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4215 // {dimensions = ...}
4216 // %divisor = arith.constant dense<...> : tensor<T>
4217 // %mean = "mhlo.divide"(%sum, %divisor)
4218 class ConvertMeanOp
4219 : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
4220 public:
4221 using GenericConvertReductionOp::GenericConvertReductionOp;
4222   static Value GetInitialValue(Type reduce_element_type, Location loc,
4223 PatternRewriter *rewriter) {
4224 return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter);
4225 }
4226 };
4227
4228 // Converts Sum op to HLO Reduce op.
4229 //
4230 // %init = arith.constant dense<...> : tensor<T>
4231 // %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
4232 // {dimensions = ...}
4233 class ConvertSumOp
4234 : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
4235 public:
4236 using GenericConvertReductionOp::GenericConvertReductionOp;
4237
4238   static Value GetInitialValue(Type reduce_element_type, Location loc,
4239 PatternRewriter *rewriter) {
4240 // The neutral element of fp addition is -0.0, not 0.0: '0.0 + -0.0 = 0.0'.
4241 return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter);
4242 }
4243 };
4244
4245 // Converts Max op to HLO Reduce op.
4246 //
4247 // %init = arith.constant dense<...> : tensor<T>
4248 // %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
4249 // {dimensions = ...}
4250 class ConvertMaxOp
4251 : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
4252 /* is_accumulation= */ false> {
4253 public:
4254 using GenericConvertReductionOp::GenericConvertReductionOp;
4255
4256   static Value GetInitialValue(Type reduce_element_type, Location loc,
4257 PatternRewriter *rewriter) {
4258 return GetScalarLimitConstOfType(reduce_element_type, loc,
4259 hlo::kInfinityLowest, rewriter);
4260 }
4261 };
4262
4263 // Converts Min op to HLO Reduce op.
4264 //
4265 // %init = arith.constant dense<...> : tensor<T>
4266 // %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"]
4267 // {dimensions = ...}
4268 class ConvertMinOp
4269 : public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
4270 /* is_accumulation= */ false> {
4271 public:
4272 using GenericConvertReductionOp::GenericConvertReductionOp;
4273
4274   static Value GetInitialValue(Type reduce_element_type, Location loc,
4275 PatternRewriter *rewriter) {
4276 return GetScalarLimitConstOfType(reduce_element_type, loc,
4277 hlo::kInfinityMax, rewriter);
4278 }
4279 };
4280
4281 // Converts Prod op to HLO Reduce op.
4282 //
4283 // %init = arith.constant dense<...> : tensor<T>
4284 // %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"]
4285 // {dimensions = ...}
4286 class ConvertProdOp
4287 : public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
4288 public:
4289 using GenericConvertReductionOp::GenericConvertReductionOp;
4290
4291   static Value GetInitialValue(Type reduce_element_type, Location loc,
4292 PatternRewriter *rewriter) {
4293 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4294 }
4295 };
4296
4297 // Converts All op to HLO Reduce op.
4298 //
4299 // %init = arith.constant dense<...> : tensor<T>
4300 // %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"]
4301 // {dimensions = ...}
4302 class ConvertAllOp
4303 : public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
4304 public:
4305 using GenericConvertReductionOp::GenericConvertReductionOp;
4306   static Value GetInitialValue(Type reduce_element_type, Location loc,
4307 PatternRewriter *rewriter) {
4308 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
4309 }
4310 };
4311
4312 // Converts Any op to HLO Reduce op.
4313 //
4314 // %init = arith.constant dense<...> : tensor<T>
4315 // %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"]
4316 // {dimensions = ...}
4317 class ConvertAnyOp
4318 : public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
4319 public:
4320 using GenericConvertReductionOp::GenericConvertReductionOp;
4321   static Value GetInitialValue(Type reduce_element_type, Location loc,
4322 PatternRewriter *rewriter) {
4323 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
4324 }
4325 };
4326
4327 // Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform
4328 // a reduction on the original input and the corresponding index. The reduction
4329 // sub-computation selects the max (or min) value and the index for the value.
4330 // Derived: the derived class implementing GetInitialValue and GetDirection.
4331 // OpTy: is TF::ArgMaxOp or TF::ArgMinOp.
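// For example (hypothetical shapes): tf.ArgMax of a tensor<2x3xf32> over
// dimension 1 reduces the pair (input values, iota of indices along dim 1)
// with a body that selects the larger value and its index, producing a
// tensor<2xi64> of indices; only the index result of the mhlo.reduce is used.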
4332 template <typename Derived, typename OpTy>
4333 class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
4334 using OpRewritePattern<OpTy>::OpRewritePattern;
4335
4336   LogicalResult matchAndRewrite(OpTy op,
4337 PatternRewriter &rewriter) const override {
4338 RankedTensorType input_type =
4339 op.input().getType().template dyn_cast<RankedTensorType>();
4340 if (!input_type) {
4341 return failure();
4342 }
4343
4344 Type input_element_type = input_type.getElementType();
4345 // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If
4346 // tf.ArgMax doesn't support complex data types, this check can be removed.
4347 if (!input_element_type.isSignlessIntOrFloat()) return failure();
4348
4349 Location loc = op.getLoc();
4350 Value init_value =
4351 Derived::GetInitialValue(input_element_type, loc, rewriter);
4352
4353 RankedTensorType output_type =
4354 op.output().getType().template dyn_cast<RankedTensorType>();
4355 if (!output_type) {
4356 return rewriter.notifyMatchFailure(op, "requires known rank");
4357 }
4358
4359 Type index_element_type = output_type.getElementType();
4360 Value index_init_value =
4361 GetScalarConstOfType(index_element_type, loc, 0, &rewriter);
4362
4363 RankedTensorType index_type =
4364 RankedTensorType::get(input_type.getShape(), index_element_type);
4365
4366 llvm::Optional<int64_t> optional_axis =
4367 GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank());
4368 if (!optional_axis.has_value())
4369 return rewriter.notifyMatchFailure(op, "required axis");
4370 int64_t axis = optional_axis.getValue();
4371
4372 IntegerAttr iota_dimension =
4373 IntegerAttr::get(rewriter.getIntegerType(64), axis);
4374 Value input_shape = rewriter.create<shape::ShapeOfOp>(loc, op.input());
4375 Value index_values = rewriter.create<DynamicIotaOp>(
4376 loc, index_type, input_shape, iota_dimension);
4377
4378 Value operands[] = {op.input(), index_values};
4379 Value init_values[] = {init_value, index_init_value};
4380 DenseIntElementsAttr reduction_dimensions =
4381 GetI64ElementsAttr({axis}, &rewriter);
4382
4383 auto reduction = rewriter.create<ReduceOp>(
4384 loc, llvm::ArrayRef<Value>(operands),
4385 llvm::ArrayRef<Value>(init_values), reduction_dimensions);
4386 auto direction = Derived::GetDirection();
4387 BuildArgMinMaxReductionBody(input_element_type, index_element_type,
4388 direction, &reduction.body(), &rewriter);
4389
4390 rewriter.replaceOp(op, {reduction.getResult(1)});
4391 return success();
4392 }
4393 };
4394
4395 // Converts tensorflow ArgMax op to mhlo operations. The actual
4396 // implementation is in class ConvertArgMinMaxOp:
4397 //
4398 // %init_index = arith.constant dense<...> : tensor<T>
4399 // %init = arith.constant dense<...> : tensor<T>
4400 // %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4401 // %init_index) ["mhlo.arg_max"]
4402 class ConvertArgMaxOp
4403 : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
4404 public:
4405 using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4406
4407   static Value GetInitialValue(Type reduce_element_type, Location loc,
4408 PatternRewriter &rewriter) {
4409 return GetScalarLimitConstOfType(reduce_element_type, loc,
4410 hlo::kInfinityLowest, &rewriter);
4411 }
4412
4413   static ComparisonDirection GetDirection() { return ComparisonDirection::GE; }
4414 };
4415
4416 // Converts tensorflow ArgMin op to mhlo operations. The actual
4417 // implementation is in class ConvertArgMinMaxOp:
4418 //
4419 // %init_index = arith.constant dense<...> : tensor<T>
4420 // %init = arith.constant dense<...> : tensor<T>
4421 // %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
4422 // %init_index) ["mhlo.arg_min"]
4423 class ConvertArgMinOp
4424 : public ConvertArgMinMaxOp<ConvertArgMinOp, TF::ArgMinOp> {
4425 public:
4426 using ConvertArgMinMaxOp::ConvertArgMinMaxOp;
4427
4428   static Value GetInitialValue(Type reduce_element_type, Location loc,
4429 PatternRewriter &rewriter) {
4430 return GetScalarLimitConstOfType(reduce_element_type, loc,
4431 hlo::kInfinityMax, &rewriter);
4432 }
4433
4434   static ComparisonDirection GetDirection() { return ComparisonDirection::LE; }
4435 };
4436
4437 // Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with
4438 // assignment:
4439 //
4440 // %result = "mhlo.scatter"(%tensor, %indices, %updates)
4441 // { dimensions = ... }
4442 //
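// For example (hypothetical shapes): a tf.TensorScatterUpdate with
// tensor<4x3xf32>, indices tensor<2x1xi32> and updates tensor<2x3xf32> uses
// num_index_dims = 1, so the scatter dimension numbers built below become
// update_window_dims = [1], inserted_window_dims = [0],
// scatter_dims_to_operand_dims = [0] and index_vector_dim = 1.
//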
4443 template <typename Derived, typename OpTy>
4444 class ConvertTensorScatterOp : public OpRewritePattern<OpTy> {
4445 public:
4446 using OpRewritePattern<OpTy>::OpRewritePattern;
4447
4448   LogicalResult matchAndRewrite(OpTy op,
4449 PatternRewriter &rewriter) const override {
4450 auto tensor_ty =
4451 op.tensor().getType().template dyn_cast<RankedTensorType>();
4452 auto indices_ty =
4453 op.indices().getType().template dyn_cast<RankedTensorType>();
4454 auto updates_ty =
4455 op.updates().getType().template dyn_cast<RankedTensorType>();
4456
4457 if (!tensor_ty || !indices_ty || !updates_ty) return failure();
4458     // The last dimension of the indices needs to be known at compile time for
4459 // computation of the 'update_window_dims' attribute in the dimensions
4460 // struct.
4461 int64_t num_index_dims = indices_ty.getShape().back();
4462 if (ShapedType::isDynamic(num_index_dims)) return failure();
4463
4464 auto updates = op.updates();
4465
4466     // Broadcast a scalar `updates` into the expected shape, which is:
4467     // updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
4468 if (updates_ty.getRank() == 0 &&
4469 (std::is_same<OpTy, TF::TensorScatterUpdateOp>::value ||
4470 std::is_same<OpTy, TF::TensorScatterAddOp>::value)) {
4471 if (!tensor_ty.hasStaticShape()) {
4472 return failure();
4473 }
4474
4475 if (!indices_ty.hasStaticShape()) {
4476 return failure();
4477 }
4478
4479 auto tensor_shape = tensor_ty.getShape();
4480 auto indices_shape = indices_ty.getShape();
4481 auto index_depth = indices_shape.back();
4482 llvm::SmallVector<int64_t> expected_update_shape;
4483
4484       // Create the expected update shape to which the scalar update is broadcast.
4485 expected_update_shape.append(indices_shape.begin(),
4486 std::prev(indices_shape.end()));
4487
4488 expected_update_shape.append(std::next(tensor_shape.begin(), index_depth),
4489 tensor_shape.end());
4490
4491 auto const_type = RankedTensorType::get(
4492 {static_cast<int>(expected_update_shape.size())},
4493 rewriter.getIntegerType(64));
4494
4495 auto const_attr = GetI64ElementsAttr(expected_update_shape, &rewriter);
4496
4497 auto const_op =
4498 rewriter.create<TF::ConstOp>(op->getLoc(), const_type, const_attr);
4499
4500 auto broadcast_to_type = RankedTensorType::get(
4501 llvm::makeArrayRef<int64_t>(expected_update_shape),
4502 updates_ty.getElementType());
4503
4504 updates = rewriter.create<TF::BroadcastToOp>(
4505 op->getLoc(), broadcast_to_type, op.updates(), const_op);
4506
4507 updates_ty = updates.getType().template dyn_cast<RankedTensorType>();
4508 }
4509
4510 int64_t tensor_rank = tensor_ty.getRank();
4511 int64_t indices_rank = indices_ty.getRank();
4512 int64_t updates_rank =
4513 updates.getType().template dyn_cast<RankedTensorType>().getRank();
4514
4515 int64_t window_dims = tensor_rank - num_index_dims;
4516 auto dims_attr = ScatterDimensionNumbersAttr::get(
4517 rewriter.getContext(),
4518 llvm::to_vector<4>(
4519 llvm::seq<int64_t>(updates_rank - window_dims, updates_rank)),
4520 llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
4521 llvm::to_vector<4>(llvm::seq<int64_t>(0, num_index_dims)),
4522 indices_rank - 1);
4523
4524 Location loc = op.getLoc();
4525 auto scatter = rewriter.create<ScatterOp>(loc, op.getType(),
4526 ValueRange(Value(op.tensor())),
4527 op.indices(), updates, dims_attr);
4528 Derived::BuildScatterBody(tensor_ty.getElementType(),
4529 &scatter.update_computation(), loc, rewriter);
4530
4531 rewriter.replaceOp(op, scatter.getResult(0));
4532 return success();
4533 }
4534 };
4535
4536 class ConvertTensorScatterUpdateOp
4537 : public ConvertTensorScatterOp<ConvertTensorScatterUpdateOp,
4538 TF::TensorScatterUpdateOp> {
4539 public:
4540 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4541
4542   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4543 OpBuilder &builder) {
4544 OpBuilder::InsertionGuard guard(builder);
4545 Block *block = builder.createBlock(region);
4546 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4547 block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4548 builder.create<ReturnOp>(loc, block->getArgument(1));
4549 }
4550 };
4551
4552 class ConvertTensorScatterAddOp
4553 : public ConvertTensorScatterOp<ConvertTensorScatterAddOp,
4554 TF::TensorScatterAddOp> {
4555 public:
4556 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4557
4558   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4559 OpBuilder &builder) {
4560 OpBuilder::InsertionGuard guard(builder);
4561 Block *block = builder.createBlock(region);
4562 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4563 block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4564 auto add_op = builder.create<AddOp>(loc, block->getArgument(0),
4565 block->getArgument(1));
4566 builder.create<ReturnOp>(loc, add_op.getResult());
4567 }
4568 };
4569
4570 class ConvertTensorScatterSubOp
4571 : public ConvertTensorScatterOp<ConvertTensorScatterSubOp,
4572 TF::TensorScatterSubOp> {
4573 public:
4574 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4575
4576   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4577 OpBuilder &builder) {
4578 OpBuilder::InsertionGuard guard(builder);
4579 Block *block = builder.createBlock(region);
4580 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4581 block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4582 auto sub_op = builder.create<SubtractOp>(loc, block->getArgument(0),
4583 block->getArgument(1));
4584 builder.create<ReturnOp>(loc, sub_op.getResult());
4585 }
4586 };
4587
4588 class ConvertTensorScatterMinOp
4589 : public ConvertTensorScatterOp<ConvertTensorScatterMinOp,
4590 TF::TensorScatterMinOp> {
4591 public:
4592 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4593
4594   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4595 OpBuilder &builder) {
4596 OpBuilder::InsertionGuard guard(builder);
4597 Block *block = builder.createBlock(region);
4598 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4599 block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4600 auto min_op = builder.create<MinOp>(loc, block->getArgument(0),
4601 block->getArgument(1));
4602 builder.create<ReturnOp>(loc, min_op.getResult());
4603 }
4604 };
4605
4606 class ConvertTensorScatterMaxOp
4607 : public ConvertTensorScatterOp<ConvertTensorScatterMaxOp,
4608 TF::TensorScatterMaxOp> {
4609 public:
4610 using ConvertTensorScatterOp::ConvertTensorScatterOp;
4611
4612   static void BuildScatterBody(Type element_type, Region *region, Location loc,
4613 OpBuilder &builder) {
4614 OpBuilder::InsertionGuard guard(builder);
4615 Block *block = builder.createBlock(region);
4616 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4617 block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4618 auto max_op = builder.create<MaxOp>(loc, block->getArgument(0),
4619 block->getArgument(1));
4620 builder.create<ReturnOp>(loc, max_op.getResult());
4621 }
4622 };
4623
4624 // Converts Tile op to HLO BroadcastInDim and Reshape ops.
4625 // For shape [S1, S2] and multiples [M1, M2],
4626 // MS1 = M1 * S1; MS2 = M2 * S2
4627 //
4628 // %broadcast = mhlo.broadcast_in_dim(%input) {
4629 //   broadcast_dimensions = [1, 3]
4630 // }
4631 // %result = "mhlo.reshape"(%broadcast) : (tensor<M1xS1xM2xS2xf32>)
4632 //           -> tensor<MS1xMS2xf32>
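// For example (hypothetical shapes): input tensor<2x3xf32> with
// multiples = [2, 1] broadcasts to tensor<2x2x3xf32> with
// broadcast_dimensions = [1, 2] (a multiple of 1 needs no extra dimension)
// and is then reshaped to the tensor<4x3xf32> result.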
4633 class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
4634 public:
4635 using OpRewritePattern::OpRewritePattern;
4636
4637   LogicalResult matchAndRewrite(TF::TileOp op,
4638 PatternRewriter &rewriter) const override {
4639 auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
4640 if (!input_ty || !input_ty.hasStaticShape()) return failure();
4641 ArrayRef<int64_t> input_shape = input_ty.getShape();
4642 Type element_type = input_ty.getElementType();
4643
4644 DenseIntElementsAttr multiples;
4645 if (!matchPattern(op.multiples(), m_Constant(&multiples)) ||
4646 multiples.getType().getRank() != 1)
4647 return failure();
4648
4649 const int64_t input_shape_size = input_shape.size();
4650 if (multiples.getNumElements() != input_shape_size) return failure();
4651
4652 SmallVector<int64_t, 8> broadcasted_shape;
4653 SmallVector<int64_t, 4> broadcast_dimensions;
4654 broadcasted_shape.reserve(input_shape.size() * 2);
4655 broadcast_dimensions.reserve(input_shape.size());
4656 for (auto multiple_and_input :
4657 llvm::zip(multiples.getValues<APInt>(), input_shape)) {
4658 int64_t multiple = std::get<0>(multiple_and_input).getSExtValue();
4659 int64_t input_size = std::get<1>(multiple_and_input);
4660
4661 if (multiple < 0) return failure();
4662
4663 // Line input up with the next dimension in broadcasted_shape
4664 // when broadcasting.
4665 int64_t broadcast_dim;
4666 int64_t output_size = input_size * multiple;
4667 if (input_size == 1 || multiple == 1) {
4668 // Special case for when normal broadcasting will just work.
4669 broadcast_dim = broadcasted_shape.size();
4670 broadcasted_shape.push_back(output_size);
4671 } else {
4672 // Tiling will happen for this dimension during the ReshapeOp below.
4673 broadcasted_shape.push_back(multiple);
4674 broadcast_dim = broadcasted_shape.size();
4675 broadcasted_shape.push_back(input_size);
4676 }
4677 broadcast_dimensions.push_back(broadcast_dim);
4678 }
4679 Location loc = op.getLoc();
4680 Type broadcasted_type =
4681 RankedTensorType::get(broadcasted_shape, element_type);
4682 Type output_type = op.getType();
4683
4684 Value result = rewriter.create<BroadcastInDimOp>(
4685 loc, broadcasted_type, op.input(),
4686 GetI64ElementsAttr(broadcast_dimensions, &rewriter));
4687
4688 if (output_type != broadcasted_type) {
4689 result = rewriter.create<ReshapeOp>(loc, output_type, result);
4690 }
4691
4692 rewriter.replaceOp(op, {result});
4693
4694 return success();
4695 }
4696 };
4697
4698 // Converts the tf.TileOp op into mhlo.dynamic_reshape
4699 // TODO(disc): Recover the static special case's performance with folding and
4700 // canonicalization.
4701 class ConvertTileOpDynamic : public OpRewritePattern<TF::TileOp> {
4702 public:
4703 using OpRewritePattern::OpRewritePattern;
4704 // clang-format off
4705 // Converts Tile op to HLO DBroadcastInDim and DReshape ops.
4706 // For shape [S1, S2] and multiples [M1, M2],
4707 // MS1 = M1 * S1; MS2 = M2 * S2
4708 //
4709 //   %out_dim_size = [M1, S1, M2, S2]
4710 //   %broadcast_dimensions = [1, 3];
4711 //   %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %broadcast_dimensions);
4712 //   %shape = [MS1, MS2]
4713 //   %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor<M1xS1xM2xS2xf32>) -> tensor<MS1xMS2xf32>
4714 // clang-format on
4715   LogicalResult matchAndRewrite(TF::TileOp op,
4716 PatternRewriter &rewriter) const final {
4717 Location loc = op.getLoc();
4718 Value input = op.input();
4719 Value multiples = op.multiples();
4720 auto input_ty = input.getType().dyn_cast<RankedTensorType>();
4721 if (!input_ty) return failure();
4722 // TODO(disc): Remove this constraint once fold and canonicalization
4723 // implemented.
4724 if (input_ty.hasStaticShape()) return failure();
4725
4726 Type element_type = input_ty.getElementType();
4727 int64_t input_rank = input_ty.getRank();
4728 SmallVector<Value, 4> input_shape_values;
4729 for (int64_t i = 0; i < input_rank; ++i) {
4730 auto dim_size = input_ty.getDimSize(i);
4731 if (dim_size == ShapedType::kDynamicSize) {
4732 input_shape_values.push_back(
4733 rewriter.create<tensor::DimOp>(loc, input, i));
4734 } else {
4735 input_shape_values.push_back(rewriter.create<arith::ConstantOp>(
4736 loc, rewriter.getIndexAttr(dim_size)));
4737 }
4738 }
4739
4740 auto multiples_ty = multiples.getType().dyn_cast<RankedTensorType>();
4741 int64_t multiples_rank = multiples_ty.getRank();
4742     // The multiples operand of tf.TileOp must be a rank-1 tensor.
4743 if (multiples_rank != 1) return failure();
4744     // The multiples operand of tf.TileOp must have a static shape.
4745 if ((!multiples_ty.hasStaticShape()) ||
4746 (multiples_ty.getDimSize(0) != input_rank)) {
4747 return failure();
4748 }
4749 Type index_ty = rewriter.getIndexType();
4750 // %out_dim_size
4751 SmallVector<Value, 4> out_dim_size;
4752 out_dim_size.reserve(input_rank * 2);
4753 for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4754 Value index = rewriter.create<arith::ConstantOp>(
4755 loc, rewriter.getIndexAttr(dim_idx));
4756 Value multiples_size =
4757 rewriter.create<tensor::ExtractOp>(loc, multiples, ValueRange{index});
4758 Value multiples_size_casted =
4759 rewriter.create<arith::IndexCastOp>(loc, index_ty, multiples_size);
4760 out_dim_size.push_back(multiples_size_casted);
4761 out_dim_size.push_back(input_shape_values[dim_idx]);
4762 }
4763 SmallVector<int64_t, 4> broadcast_dimensions;
4764 broadcast_dimensions.reserve(input_rank);
4765 for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) {
4766 broadcast_dimensions.push_back(1 + 2 * dim_idx);
4767 }
4768 auto broadcast_dims_attr =
4769 GetI64ElementsAttr(broadcast_dimensions, &rewriter);
4770
4771 Value out_dim_size_tensor = rewriter.create<tensor::FromElementsOp>(
4772 loc,
4773 RankedTensorType::get({static_cast<int64_t>(out_dim_size.size())},
4774 index_ty),
4775 out_dim_size);
4776 SmallVector<int64_t, 4> broadcast_shape(input_rank * 2,
4777 ShapedType::kDynamicSize);
4778 RankedTensorType broadcast_type =
4779 RankedTensorType::get(broadcast_shape, element_type);
4780 Value broadcast = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
4781 loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr);
4782
4783 // %shape = [MS1, MS2]
4784 SmallVector<Value, 4> shape_values;
4785 shape_values.reserve(input_rank);
4786 for (int64_t i = 0; i < input_rank; ++i) {
4787 Value dim_size_value = rewriter.create<mlir::arith::MulIOp>(
4788 loc, out_dim_size[2 * i], out_dim_size[2 * i + 1]);
4789 shape_values.push_back(dim_size_value);
4790 }
4791 Value shape = rewriter.create<tensor::FromElementsOp>(
4792 loc, RankedTensorType::get({input_rank}, index_ty), shape_values);
4793 rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, op.getType(),
4794 broadcast, shape);
4795 return success();
4796 }
4797 };
4798
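// Converts tf.MaxPoolGrad/tf.MaxPool3DGrad to an mhlo.select_and_scatter op:
// the select region picks the window element to receive the gradient using a
// GE comparison on the original input, and the scatter region adds the
// incoming gradient values, starting from a zero initial value.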
4799 template <typename OpTy, int num_dims>
4800 class ConvertMaxPoolGradOp : public OpRewritePattern<OpTy> {
4801 public:
4802 using OpRewritePattern<OpTy>::OpRewritePattern;
4803
4804   LogicalResult matchAndRewrite(OpTy op,
4805 PatternRewriter &rewriter) const override {
4806 Location loc = op.getLoc();
4807
4808 Type element_type =
4809 op.orig_input().getType().template cast<TensorType>().getElementType();
4810
4811 // Compute paddings using the original input and kernel shape and strides.
4812     // Here the same padding computation is used as when the MaxPool op is
4813     // lowered to the ReduceWindow op.
4814 auto input_ty =
4815 op.orig_input().getType().template dyn_cast<RankedTensorType>();
4816 if (!input_ty) return failure();
4817 DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr<num_dims>(
4818 input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
4819
4820 auto result = rewriter.create<SelectAndScatterOp>(
4821 loc, op.getType(), op.orig_input(), op.grad(),
4822 GetScalarConstOfType(element_type, loc, 0, &rewriter),
4823 GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()),
4824 paddings_attr);
4825
4826 BuildReduceBody<AddOp>(element_type, &result.scatter(), &rewriter);
4827 {
4828 OpBuilder::InsertionGuard guard(rewriter);
4829 Block *block = rewriter.createBlock(&result.select());
4830
4831 // Block arguments are scalars of the given element type.
4832 Type type = RankedTensorType::get(/*shape=*/{}, element_type);
4833 block->addArguments({type, type}, SmallVector<Location, 2>(2, loc));
4834
4835 auto reducer = rewriter.create<CompareOp>(loc, block->getArgument(0),
4836 block->getArgument(1),
4837 ComparisonDirection::GE);
4838 rewriter.create<ReturnOp>(loc, reducer.getResult());
4839 }
4840
4841 rewriter.replaceOp(op, {result});
4842
4843 return success();
4844 }
4845 };
4846
4847 using ConvertMaxPool2DGradOp =
4848 ConvertMaxPoolGradOp<TF::MaxPoolGradOp, /*num_dims=*/4>;
4849 using ConvertMaxPool3DGradOp =
4850 ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
4851
4852 // Converts tf.Conv?DBackpropInputOp into:
4853 // %rev_filter = "mhlo.reverse"(%filter)
4854 // %result = "mhlo.convolution"(%out_backprop, %rev_filter)
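//
// As a worked example (hypothetical sizes, one spatial dimension shown): for
// input_size = 5, filter_size = 3, stride = 2 and SAME padding, the forward
// output size is 3. The lowering below dilates the 3 gradient values with
// lhs_dilation = 2 (to 5 positions), pads with pad_before = 1 and
// pad_after = 1 (to 7 positions), and convolves with the spatially reversed
// filter to recover the 5 input positions.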
4855 template <typename OpTy, int num_spatial_dims>
4856 class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
4857 public:
4858 using OpRewritePattern<OpTy>::OpRewritePattern;
4859
4860   LogicalResult matchAndRewrite(OpTy op,
4861 PatternRewriter &rewriter) const override {
4862 // Unpack all of the attributes.
4863 tensorflow::TensorFormat data_format;
4864 if (!FormatFromString(op.data_format().str(), &data_format))
4865 return op.emitOpError("invalid data format");
4866 constexpr int num_dims = num_spatial_dims + 2;
4867 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
4868
4869 tensorflow::Padding padding;
4870 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
4871 return failure();
4872
4873 auto out_backprop_ty =
4874 op.out_backprop().getType().template dyn_cast<RankedTensorType>();
4875 auto filter_ty =
4876 op.filter().getType().template dyn_cast<RankedTensorType>();
4877
4878 // With the exception of out_backprop's batch dimension, out_backprop and
4879 // filter need to have static shape. Filter is validated here, out_backprop
4880 // is mostly validated at use.
4881 if (!out_backprop_ty || !filter_ty || !filter_ty.hasStaticShape())
4882 return failure();
4883
4884 // Compute input_shape by supporting either:
4885 // 1) Fully static shapes, represented as constants.
4886 // 2) Static shapes with a dynamic batch dimension, represented as
4887 // 1D tf.Pack of a batch dimension (can be static or dynamic)
4888 // and other dimensions (can only be static), for example:
4889 // "tf.Pack"(%142, %cst_301, %cst_301, %cst_300) {axis = 0 : i64, ...}
4890 std::vector<int64_t> input_shape;
4891 DenseIntElementsAttr input_shape_attr;
4892 if (matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) &&
4893 input_shape_attr.getType().getRank() == 1) {
4894 input_shape.insert(input_shape.end(),
4895 input_shape_attr.getValues<int32_t>().begin(),
4896 input_shape_attr.getValues<int32_t>().end());
4897 } else {
4898 auto pack = op.input_sizes().template getDefiningOp<TF::PackOp>();
4899 if (!pack || pack.axis() != 0) return failure();
4900 auto pack_ty = pack.getType().template dyn_cast<RankedTensorType>();
4901 if (!pack_ty || pack_ty.getRank() != 1) return failure();
4902 for (auto i = 0; i < pack_ty.getDimSize(0); ++i) {
4903 if (i == batch_dim) {
4904 // We don't use the batch dimension below, so we don't care about
4905 // its size. Might as well populate it with -1.
4906 input_shape.push_back(ShapedType::kDynamicSize);
4907 } else {
4908 DenseIntElementsAttr input_dims_attr;
4909 if (matchPattern(pack.values()[i], m_Constant(&input_dims_attr)) &&
4910 input_dims_attr.getType().getRank() == 0) {
4911 input_shape.push_back(input_dims_attr.getSplatValue<int32_t>());
4912 } else {
4913 return failure();
4914 }
4915 }
4916 }
4917 }
4918
4919 auto dilations_attr = GetI64ElementsAttr(op.dilations());
4920 std::vector<int> dilations{
4921 dilations_attr.template getValues<int64_t>().begin(),
4922 dilations_attr.template getValues<int64_t>().end()};
4923 auto strides_attr = GetI64ElementsAttr(op.strides());
4924 std::vector<tensorflow::int32> strides{
4925 strides_attr.template getValues<int64_t>().begin(),
4926 strides_attr.template getValues<int64_t>().end()};
4927
4928 std::vector<int64_t> explicit_paddings;
4929 if (padding == tensorflow::Padding::EXPLICIT) {
4930       // EXPLICIT padding mode and the associated attribute are limited to
4931       // Conv2DBackpropInput. So, fetch the attribute by identifier instead of
4932       // using the op.explicit_paddings() attribute getter.
4933 ArrayRef<Attribute> explicit_paddings_attr =
4934 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
4935 explicit_paddings.reserve(explicit_paddings_attr.size());
4936 for (Attribute explicit_padding : explicit_paddings_attr)
4937 explicit_paddings.push_back(
4938 explicit_padding.cast<IntegerAttr>().getInt());
4939 }
4940
4941 ArrayRef<int64_t> filter_shape = filter_ty.getShape();
4942
4943 // Compute ConvDimensionNumbers, dilation, and padding.
4944 SmallVector<int64_t, num_spatial_dims> spatial_dims;
4945 SmallVector<int64_t, num_spatial_dims> lhs_dilation;
4946 SmallVector<int64_t, num_spatial_dims> rhs_dilation;
4947 SmallVector<int64_t, num_spatial_dims * 2> paddings;
4948
4949 for (int i : llvm::seq<int>(0, num_spatial_dims)) {
4950 const int64_t spatial_dim =
4951 GetTensorSpatialDimIndex(num_dims, data_format, i);
4952 spatial_dims.push_back(spatial_dim);
4953
4954 // Prepare metadata indexed by spatial_dim for computing pad_before
4955 // and pad_after.
4956 int64_t input_size = input_shape[spatial_dim];
4957 if (input_size == ShapedType::kDynamicSize) return failure();
4958 int64_t output_size = out_backprop_ty.getDimSize(spatial_dim);
4959 if (output_size == ShapedType::kDynamicSize) return failure();
4960 int64_t filter_size = filter_ty.getDimSize(i);
4961 int64_t stride = strides[spatial_dim];
4962 int64_t dilation = dilations[spatial_dim];
4963
4964 // Compute pad_before and pad_after following the logic from
4965 // ConvBackpropComputeDimensionsV2. (Unfortunately, we cannot call
4966 // the function in question because it doesn't work with dynamic dims).
4967 int64_t padding_before = -1, padding_after = -1;
4968 if (padding == tensorflow::Padding::EXPLICIT) {
4969 padding_before = explicit_paddings[2 * spatial_dim];
4970 padding_after = explicit_paddings[2 * spatial_dim + 1];
4971 }
4972 int64_t expected_output_size = 0;
4973 auto status = GetWindowedOutputSizeVerboseV2(
4974 input_size, filter_size, dilation, stride, padding,
4975 &expected_output_size, &padding_before, &padding_after);
4976 if (!status.ok()) return failure();
4977 if (output_size != expected_output_size) return failure();
4978 int64_t effective_filter_size = (filter_size - 1) * dilation + 1;
4979 int64_t pad_before = effective_filter_size - 1 - padding_before;
4980 int64_t padded_out_size = input_size + effective_filter_size - 1;
4981 int64_t expanded_output_size = (output_size - 1) * stride + 1;
4982 int64_t pad_after = padded_out_size - expanded_output_size - pad_before;
4983
4984 // Populate metadata for the upcoming mhlo.conv op using the result of
4985 // the computations performed above.
4986 lhs_dilation.push_back(stride);
4987 rhs_dilation.push_back(dilation);
4988 paddings.push_back(pad_before);
4989 paddings.push_back(pad_after);
4990 }
4991
4992 RankedTensorType paddings_ty = RankedTensorType::get(
4993 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
4994 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
4995
4996 Value filter = op.filter();
4997
4998 const int feature_dim =
4999 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5000 const int64_t in_depth = *(input_shape.begin() + feature_dim);
5001 if (in_depth == ShapedType::kDynamicSize) return failure();
5002 const int64_t filter_in_depth = filter_shape[num_spatial_dims];
5003 const int64_t feature_group_count = in_depth / filter_in_depth;
5004
5005 if (feature_group_count != 1) {
5006 // 1. Reshape filter from
5007 // [H, W, ..., filter_in_depth, out_depth] to
5008 // [H, W, ..., filter_in_depth, G, out_depth / G].
5009 auto new_shape = llvm::to_vector<6>(filter_shape);
5010 new_shape.back() = feature_group_count;
5011 new_shape.push_back(filter_shape.back() / feature_group_count);
5012 Type filter_element_ty = filter_ty.getElementType();
5013 auto ty = RankedTensorType::get(new_shape, filter_element_ty);
5014 filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5015
5016 // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G].
5017 llvm::SmallVector<int64_t, 6> perm(num_dims + 1);
5018 std::iota(perm.begin(), perm.end(), 0);
5019 std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]);
5020 std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]);
5021 ty = RankedTensorType::get(new_shape, filter_element_ty);
5022 filter = rewriter.create<TransposeOp>(
5023 op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter));
5024
5025 // 3. Reshape to [H, W, ..., in_depth, out_depth / G].
5026 new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1];
5027 new_shape[num_spatial_dims + 1] = new_shape.back();
5028 new_shape.pop_back();
5029 ty = RankedTensorType::get(new_shape, filter_element_ty);
5030 filter = rewriter.create<ReshapeOp>(op.getLoc(), ty, filter);
5031 }
5032
5033 SmallVector<int64_t, 4> kernel_spatial_dims;
5034 kernel_spatial_dims.resize(num_spatial_dims);
5035 std::iota(kernel_spatial_dims.begin(), kernel_spatial_dims.end(), 0);
5036
5037 // Mirror the filter in the spatial dimensions.
5038 filter = rewriter.create<ReverseOp>(
5039 op.getLoc(), filter,
5040 GetI64ElementsAttr(kernel_spatial_dims, &rewriter));
5041
5042 // activation gradients
5043 // = gradients (with padding and dilation) <conv> mirrored_weights
5044 Value result = rewriter.create<ConvolutionOp>(
5045 op.getLoc(), op.getType(), op.out_backprop(), filter,
5046 /*window_strides=*/
5047 GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5048 &rewriter),
5049 /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
5050 GetI64ElementsAttr(rhs_dilation, &rewriter),
5051 /*window_reversal=*/nullptr,
5052 ConvDimensionNumbersAttr::get(
5053 rewriter.getContext(),
5054 /*input_batch_dimension=*/batch_dim,
5055 /*input_feature_dimension=*/feature_dim,
5056 /*input_spatial_dimensions=*/spatial_dims,
5057 // TF filter shape is [ H, W, ..., inC, outC ]
5058 // Transpose the input and output features for computing the
5059 // gradient.
5060 /*kernel_input_feature_dimension=*/
5061 num_spatial_dims + 1,
5062 /*kernel_output_feature_dimension=*/
5063 num_spatial_dims,
5064 /*kernel_spatial_dimensions=*/kernel_spatial_dims,
5065 /*output_batch_dimension=*/batch_dim,
5066 /*output_feature_dimension=*/feature_dim,
5067 /*output_spatial_dimensions=*/spatial_dims),
5068 rewriter.getI64IntegerAttr(feature_group_count),
5069 /*batch_group_count=*/rewriter.getI64IntegerAttr(1),
5070 /*precision_config=*/ArrayAttr());
5071
5072 rewriter.replaceOp(op, {result});
5073
5074 return success();
5075 }
5076 };
5077
5078 using ConvertConv2DBackpropInputOp =
5079 ConvertConvBackpropInputOp<TF::Conv2DBackpropInputOp,
5080 /*num_spatial_dims=*/2>;
5081 using ConvertConv3DBackpropInputOp =
5082 ConvertConvBackpropInputOp<TF::Conv3DBackpropInputV2Op,
5083 /*num_spatial_dims=*/3>;
5084
5085 // Converts tf.Conv?DBackpropFilterOp into:
5086 // %result = "mhlo.convolution"(%input, %out_backprop)
5087 template <typename OpTy, int num_spatial_dims>
5088 class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
5089 public:
5090 using OpRewritePattern<OpTy>::OpRewritePattern;
5091
5092 LogicalResult matchAndRewrite(OpTy op,
5093 PatternRewriter &rewriter) const override {
5094 // Unpack all of the attributes.
5095 tensorflow::TensorFormat data_format;
5096 if (!FormatFromString(op.data_format().str(), &data_format))
5097 return op.emitOpError("invalid data format");
5098
5099 tensorflow::Padding padding;
5100 if (!GetPaddingFromString(op.padding().str(), &padding).ok())
5101 return failure();
5102
5103 auto out_backprop_ty =
5104 op.out_backprop().getType().template dyn_cast<RankedTensorType>();
5105 auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
5106
5107 for (RankedTensorType ty : {out_backprop_ty, input_ty})
5108 if (!ty || !ty.hasStaticShape()) return failure();
5109
5110 ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
5111 ArrayRef<int64_t> input_shape = input_ty.getShape();
5112
5113 DenseIntElementsAttr filter_shape_attr;
5114 if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) ||
5115 filter_shape_attr.getType().getRank() != 1)
5116 return failure();
5117
5118 auto dilations_attr = GetI64ElementsAttr(op.dilations());
5119 std::vector<int> dilations{
5120 dilations_attr.template getValues<int64_t>().begin(),
5121 dilations_attr.template getValues<int64_t>().end()};
5122 auto strides_attr = GetI64ElementsAttr(op.strides());
5123 std::vector<tensorflow::int32> strides{
5124 strides_attr.template getValues<int64_t>().begin(),
5125 strides_attr.template getValues<int64_t>().end()};
5126
5127 std::vector<int64_t> explicit_paddings;
5128 if (padding == tensorflow::Padding::EXPLICIT) {
5129 // EXPLICIT padding mode and the associated attribute are limited to
5130 // Conv2DBackpropFilter. So, fetch the attribute by identifier instead of the
5131 // op.explicit_paddings() attribute getter.
5132 ArrayRef<Attribute> explicit_paddings_attr =
5133 op->template getAttrOfType<ArrayAttr>("explicit_paddings").getValue();
5134 explicit_paddings.reserve(explicit_paddings_attr.size());
5135 for (Attribute explicit_padding : explicit_paddings_attr)
5136 explicit_paddings.push_back(
5137 explicit_padding.cast<IntegerAttr>().getInt());
5138 }
5139
5140 constexpr int num_dims = num_spatial_dims + 2;
5141 auto filter_shape = filter_shape_attr.getValues<int32_t>();
5142
5143 // Reuse dimension computation logic from conv_grad_shape_utils.cc.
5144 tensorflow::ConvBackpropDimensions dims;
5145 if (!tensorflow::ConvBackpropComputeDimensionsV2(
5146 /*label=*/"", num_spatial_dims,
5147 ToTensorShape<int64_t, num_dims>(input_shape),
5148 ToTensorShape<int32_t, num_dims>(filter_shape),
5149 ToTensorShape<int64_t, num_dims>(out_backprop_shape), dilations,
5150 strides, padding, explicit_paddings, data_format, &dims)
5151 .ok()) {
5152 return failure();
5153 }
5154
5155 // The activations (inputs) form the LHS of the convolution.
5156 // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
5157 // For the gradient computation, we need to:
5158 // 1. In the case of group convolution, move the num_groups dimension before
5159 // the batch dimension
5160 // 2. Swap the roles of the batch and feature dimensions.
5161 const int feature_dim =
5162 tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);
5163 const int64_t in_depth = input_shape[feature_dim];
5164 const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims);
5165 const int64_t batch_group_count = in_depth / filter_in_depth;
5166
5167 // Compute ConvDimensionNumbers, dilation, and padding.
5168 SmallVector<int64_t, num_spatial_dims> spatial_dims;
5169 SmallVector<int64_t, num_spatial_dims> kernel_spatial_dims;
5170 SmallVector<int64_t, num_spatial_dims> rhs_dilation;
5171 SmallVector<int64_t, num_spatial_dims * 2> paddings;
5172 SmallVector<int64_t, num_spatial_dims> window_strides;
5173
5174 // The filter gradients are computed by a convolution of the input
5175 // activations and the output gradients, with some appropriate padding.
5176 // See the comment at the top of conv_grad_ops.h for details.
5177
5178 for (int i : llvm::seq<int>(0, num_spatial_dims)) {
5179 const int64_t dim =
5180 tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i);
5181 kernel_spatial_dims.push_back(dim);
5182 // Besides padding the input, we will also expand output_rows to
5183 // expanded_out_rows = (output_rows - 1) * stride + 1
5184 // with zeros in between:
5185 //
5186 // a . . . b . . . c . . . d . . . e
5187 //
5188 // This is done by specifying the window dilation factors in the
5189 // convolution HLO below.
5190 const auto &spatial_dim_i = dims.spatial_dims[i];
5191 rhs_dilation.push_back(spatial_dim_i.stride);
5192 window_strides.push_back(dilations[dim]);
5193
5194 // We will also need to pad the input with zeros such that after the
5195 // convolution, we get the right size for the filter.
5196 // The padded_in_rows should be such that when we convolve this with the
5197 // expanded_out_rows as a filter, we should get filter_rows back.
5198
5199 const int64_t padded_in_size =
5200 spatial_dim_i.expanded_output_size +
5201 (spatial_dim_i.filter_size - 1) * dilations[dim];
5202
5203 // However it can be smaller than input_rows: in this
5204 // case it means some of the inputs are not used.
5205 //
5206 // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
5207 //
5208 // INPUT = [ A B C ]
5209 //
5210 // FILTER = [ x y ]
5211 //
5212 // and the output will only have one column: a = A * x + B * y
5213 //
5214 // and input "C" is not used at all.
5215 //
5216 // We apply negative padding in this case.
5217 const int64_t pad_total = padded_in_size - spatial_dim_i.input_size;
5218
5219 // + For the EXPLICIT padding, we pad the top/left side with the explicit
5220 // padding and pad the bottom/right side with the remaining space.
5221 // + For the VALID padding, we don't pad anything on the top/left side
5222 // and pad the bottom/right side with the remaining space.
5223 // + For the SAME padding, we pad top/left side the same as bottom/right
5224 // side.
5225 //
5226 // In addition, if the padded input size is smaller than the input size,
5227 // we need to ignore some trailing elements of the input. We do this by
5228 // applying negative padding on the right/bottom.
5229 const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT
5230 ? explicit_paddings[2 * dim]
5231 : padding == tensorflow::Padding::SAME
5232 ? std::max<int64_t>(pad_total / 2, 0)
5233 : 0;
5234 paddings.push_back(pad_before);
5235 paddings.push_back(pad_total - pad_before);
5236 }
5237
5238 RankedTensorType paddings_ty = RankedTensorType::get(
5239 {num_spatial_dims, 2}, rewriter.getIntegerType(64));
5240 auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings);
5241
5242 SmallVector<int64_t, 4> output_spatial_dimensions;
5243 output_spatial_dimensions.resize(num_spatial_dims);
5244 std::iota(output_spatial_dimensions.begin(),
5245 output_spatial_dimensions.end(), 0);
5246
5247 const int batch_dim =
5248 tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
5249
5250 Value result = rewriter.create<ConvolutionOp>(
5251 op.getLoc(), op.getType(), op.input(), op.out_backprop(),
5252 /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter),
5253 /*padding=*/paddings_attr, /*lhs_dilation=*/
5254 GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
5255 &rewriter),
5256 GetI64ElementsAttr(rhs_dilation, &rewriter),
5257 /*window_reversal=*/nullptr,
5258 ConvDimensionNumbersAttr::get(
5259 rewriter.getContext(),
5260 // Swap batch_dim and feature_dim in the activations.
5261 /*input_batch_dimension=*/feature_dim,
5262 /*input_feature_dimension=*/batch_dim,
5263 /*input_spatial_dimensions=*/kernel_spatial_dims,
5264 // The gradients become the RHS of the convolution.
5265 // The gradients have shape [batch, out_rows, out_cols, ...,
5266 // out_depth] where the batch becomes the input feature for the
5267 // convolution.
5268 /*kernel_input_feature_dimension=*/batch_dim,
5269 /*kernel_output_feature_dimension=*/feature_dim,
5270 /*kernel_spatial_dimensions=*/kernel_spatial_dims,
5271 /*output_batch_dimension=*/num_spatial_dims,
5272 /*output_feature_dimension=*/num_spatial_dims + 1,
5273 /*output_spatial_dimensions=*/output_spatial_dimensions),
5274 /*feature_group_count=*/rewriter.getI64IntegerAttr(1),
5275 rewriter.getI64IntegerAttr(batch_group_count),
5276 /*precision_config=*/ArrayAttr());
5277
5278 rewriter.replaceOp(op, {result});
5279
5280 return success();
5281 }
5282 };
5283
5284 using ConvertConv2DBackpropFilterOp =
5285 ConvertConvBackpropFilterOp<TF::Conv2DBackpropFilterOp,
5286 /*num_spatial_dims=*/2>;
5287 using ConvertConv3DBackpropFilterOp =
5288 ConvertConvBackpropFilterOp<TF::Conv3DBackpropFilterV2Op,
5289 /*num_spatial_dims=*/3>;
5290
5291 class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
5292 public:
5293 using OpRewritePattern::OpRewritePattern;
5294
5295 LogicalResult matchAndRewrite(TF::OneHotOp op,
5296 PatternRewriter &rewriter) const override {
5297 auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
5298 if (!indices_ty || !indices_ty.hasStaticShape()) return failure();
5299 ArrayRef<int64_t> indices_shape = indices_ty.getShape();
5300 Type element_type = indices_ty.getElementType();
5301
5302 DenseIntElementsAttr depth_attr;
5303 if (!matchPattern(op.depth(), m_Constant(&depth_attr))) {
5304 return failure();
5305 }
5306
5307 int64_t depth = depth_attr.getValues<APInt>()[0].getSExtValue();
5308 int64_t axis = op.axis();
5309 if (axis == -1) axis = indices_shape.size();
5310
5311 llvm::SmallVector<int64_t, 4> broadcast_dims(indices_shape.size());
5312 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
5313 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
5314
5315 llvm::SmallVector<int64_t, 4> output_dims =
5316 llvm::to_vector<4>(indices_shape);
5317 output_dims.insert(output_dims.begin() + axis, depth);
5318
5319 Location loc = op.getLoc();
5320
5321 // The iota result is the effective output shape of the computation,
5322 // and indices must be broadcast into it. At this point, this computation
5323 // would need to be reworked quite a bit to support dynamic shapes, so
5324 // just using static broadcasting.
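// A sketch of the resulting ops for indices : tensor<2xi32>, depth = 3 and
// axis = 1 (illustrative only; the exact attribute printing may differ):
//
//   %iota = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3xi32>
//   %idx  = "mhlo.broadcast_in_dim"(%indices)
//       {broadcast_dimensions = dense<0> : tensor<1xi64>}
//       : (tensor<2xi32>) -> tensor<2x3xi32>
//   %eq   = "mhlo.compare"(%idx, %iota) {comparison_direction = ...EQ...}
//   %on   = "mhlo.broadcast"(%on_value)  : (tensor<f32>) -> tensor<2x3xf32>
//   %off  = "mhlo.broadcast"(%off_value) : (tensor<f32>) -> tensor<2x3xf32>
//   %res  = "mhlo.select"(%eq, %on, %off) : ... -> tensor<2x3xf32>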
5325 auto index_type = RankedTensorType::get(output_dims, element_type);
5326 auto iota = rewriter.create<IotaOp>(
5327 loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis));
5328 auto broadcast_indices = rewriter.create<BroadcastInDimOp>(
5329 loc, index_type, op.indices(),
5330 GetI64ElementsAttr(broadcast_dims, &rewriter));
5331
5332 Value compare = rewriter.create<mhlo::CompareOp>(
5333 loc, broadcast_indices, iota, ComparisonDirection::EQ);
5334 Value on_value = rewriter.create<BroadcastOp>(
5335 loc, op.getType(), op.on_value(),
5336 GetI64ElementsAttr(output_dims, &rewriter));
5337 Value off_value = rewriter.create<BroadcastOp>(
5338 loc, op.getType(), op.off_value(),
5339 GetI64ElementsAttr(output_dims, &rewriter));
5340 Value result = rewriter.create<SelectOp>(loc, op.getType(), compare,
5341 on_value, off_value);
5342
5343 rewriter.replaceOp(op, {result});
5344
5345 return success();
5346 }
5347 };
5348
5349 // Converts InfeedDequeueTuple to XLA HLO create_token, infeed and
5350 // get_tuple_element ops.
5351 //
5352 // All HLO infeed ops expect an HLO token type operand and produce a tuple
5353 // containing a token. This HLO token type is used to order multiple infeed
5354 // operations within a computation. The token type can come from other
5355 // infeed/outfeed/send/recv ops or can be generated using create_token op with
5356 // no operands. Here we emit a create_token op to generate the token type
5357 // operand of infeed. The mhlo.InfeedOp can produce multiple results and later
5358 // will be exported to XLA infeed op with single tuple return type.
5359 //
5360 // For example the following IR:
5361 // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>)
5362 //
5363 // would be lowered to
5364 //
5365 // %token = "mhlo.create_token"() : () -> !mhlo.token
5366 // %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} :
5367 //   (!mhlo.token) -> (tensor<3xi32>, tensor<4xf32>, !mhlo.token)
5368 //
5369 class ConvertInfeedDequeueTupleOp
5370 : public OpRewritePattern<TF::InfeedDequeueTupleOp> {
5371 public:
5372 using OpRewritePattern::OpRewritePattern;
5373
5374 LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op,
5375 PatternRewriter &rewriter) const override {
5376 SmallVector<Type> result_types;
5377 result_types.reserve(op.outputs().size() + 1);
5378 for (const auto &output : op.outputs()) {
5379 Type ty = output.getType();
5380 if (auto tensor_ty = ty.dyn_cast<RankedTensorType>()) {
5381 if (!tensor_ty.hasStaticShape()) return failure();
5382 }
5383 result_types.push_back(ty);
5384 }
5385
5386 // Infeed takes a single token operand. Generate the token using
5387 // create_token op to pass to the infeed op.
5388 auto token = rewriter.create<CreateTokenOp>(
5389 op.getLoc(), mhlo::TokenType::get(rewriter.getContext()));
5390
5391 result_types.push_back(token.getType());
5392
5393 ArrayAttr layout; // filled in during the xla-adjust-layout pass
5394 auto data_and_token =
5395 rewriter.create<InfeedOp>(op.getLoc(), result_types, token,
5396 /*infeed_config=*/rewriter.getStringAttr(""),
5397 /*layout=*/layout);
5398
5399 result_types.pop_back(); // remove the token type.
5400
5401 if (op._XlaSharding().has_value()) {
5402 // _XlaSharding attribute in TF is a serialized string of the OpSharding
5403 // proto, so convert to a text form here.
5404 ::xla::OpSharding sharding_proto;
5405 if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str()))
5406 return failure();
5407
5408 // Token is a control signal and not real data, so arbitrarily assign
5409 // the token to device 0.
5410 if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
5411 *sharding_proto.add_tuple_shardings() =
5412 ::xla::sharding_builder::AssignDevice(0);
5413 data_and_token->setAttr(
5414 kShardingAttr,
5415 rewriter.getStringAttr(sharding_proto.SerializeAsString()));
5416 } else {
5417 data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
5418 }
5419 }
5420
5421 if (op->hasAttr("layouts")) {
5422 // Append a UnitAttr for the "token" operand of the mhlo.infeed op here to
5423 // avoid compilation failure when exporting "layouts" attribute of the
5424 // corresponding InfeedDequeueTupleOp to a graph node.
5425 data_and_token->setAttr("layout", op->getAttr("layouts"));
5426 }
5427 llvm::SmallVector<Value> results;
5428 results.reserve(result_types.size());
5429 for (auto idx_and_type : llvm::enumerate(result_types)) {
5430 results.push_back(data_and_token.getResult(idx_and_type.index()));
5431 }
5432 rewriter.replaceOp(op, ValueRange(results));
5433 return success();
5434 }
5435 };
5436
5437 // Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed
5438 // ops.
5439 //
5440 // XLA HLO outfeed op expects a token, which we generate by emitting a
5441 // create_token op.
5442 //
5443 // For example the following IR:
5444 // "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
5445 // ()
5446 //
5447 // would be lowered to
5448 //
5449 // %token = "mhlo.create_token"() : () -> !mhlo.token
5450 // %outfeed_token = "mhlo.outfeed"(%val_1, %val_2, %token) {outfeed_config = ""}
5451 // :
5452 // (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token
5453 //
5454 class ConvertOutfeedEnqueueTupleOp
5455 : public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
5456 public:
5457 using OpRewritePattern::OpRewritePattern;
5458
5459 LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
5460 PatternRewriter &rewriter) const override {
5461 auto token_type = mhlo::TokenType::get(rewriter.getContext());
5462 auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
5463
5464 rewriter.create<OutfeedOp>(op.getLoc(), token_type, op.inputs(), token,
5465 /*outfeed_config=*/rewriter.getStringAttr(""));
5466 rewriter.eraseOp(op);
5467 return success();
5468 }
5469 };
5470
5471 // Converts tf.TopKV2 to chlo.top_k.
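// For example (an illustrative sketch; the exact chlo printing may differ),
// with a constant k = 4:
//
//   %values, %indices = "tf.TopKV2"(%input, %k)
//       : (tensor<16x8xf32>, tensor<i32>) -> (tensor<16x4xf32>, tensor<16x4xi32>)
//
// would be lowered to
//
//   %values, %indices = chlo.top_k(%input, k = 4)
//       : tensor<16x8xf32> -> (tensor<16x4xf32>, tensor<16x4xi32>)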
5472 class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
5473 public:
5474 using OpRewritePattern::OpRewritePattern;
5475
5476 LogicalResult matchAndRewrite(TF::TopKV2Op op,
5477 PatternRewriter &rewriter) const override {
5478 // We can only match when the `k` operand is a constant scalar.
5479 DenseIntElementsAttr k_attr;
5480 if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure();
5481 int64_t k = (*k_attr.begin()).getSExtValue();
5482
5483 TensorType input_type = op.input().getType().cast<TensorType>();
5484 if (!input_type.hasRank()) return failure();
5485 int64_t input_rank = input_type.getRank();
5486 int64_t last_dim_index = input_rank - 1;
5487 int64_t last_dim_size = input_type.getDimSize(last_dim_index);
5488 if (last_dim_size == ShapedType::kDynamicSize) return failure();
5489
5490 rewriter.replaceOpWithNewOp<chlo::TopKOp>(op, op.input(), k);
5491 return success();
5492 }
5493 };
5494
5495 // Converts tf.Unpack to a series of XLA HLO slice ops.
5496 //
5497 // Each slice takes one element along the dimension to unpack and takes the full
5498 // range for all other dimensions. Each slice is then reshaped to drop the
5499 // dimension to unpack (which is always of size 1).
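// For example (an illustrative sketch; the exact op printing may differ):
//
//   %0:3 = "tf.Unpack"(%input) {axis = 0 : i64}
//       : (tensor<3x4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
//
// is rewritten so that result i (for i in 0..2) is produced by
//
//   %slice_i = "mhlo.slice"(%input) {start_indices = dense<[i, 0]>,
//       limit_indices = dense<[i + 1, 4]>, strides = dense<1>}
//       : (tensor<3x4xf32>) -> tensor<1x4xf32>
//   %result_i = "tf.Squeeze"(%slice_i) {squeeze_dims = [0]}
//       : (tensor<1x4xf32>) -> tensor<4xf32>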
5500 // TODO(antiagainst): consider changing this into a TF internal lowering pass.
5501 class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
5502 public:
5503 using OpRewritePattern::OpRewritePattern;
5504
5505 LogicalResult matchAndRewrite(TF::UnpackOp op,
5506 PatternRewriter &rewriter) const override {
5507 auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5508 if (!value_type) return failure();
5509
5510 int64_t value_rank = value_type.getRank();
5511 int64_t axis = op.axis();
5512 if (axis < 0) axis += value_rank;
5513
5514 // Parameters for constructing each slice.
5515 SmallVector<int64_t, 4> begin_indices(value_rank, 0);
5516 auto end_indices = llvm::to_vector<4>(value_type.getShape());
5517 SmallVector<int64_t, 4> strides(value_rank, 1);
5518
5519 // All HLO slice+squeeze results used to replace the original tf.Unpack op.
5520 SmallVector<Value, 4> results;
5521 results.reserve(op.getNumResults());
5522
5523 for (int i = 0, end = op.getNumResults(); i < end; ++i) {
5524 begin_indices[axis] = i;
5525 end_indices[axis] = i + 1;
5526
5527 auto slice_op = rewriter.create<mhlo::SliceOp>(
5528 op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
5529 GetI64ElementsAttr(end_indices, &rewriter),
5530 GetI64ElementsAttr(strides, &rewriter));
5531 // Reshape to drop the axis dimension.
5532 auto result =
5533 rewriter.create<TF::SqueezeOp>(op.getLoc(), op.getType(i), slice_op,
5534 rewriter.getI64ArrayAttr(op.axis()));
5535 results.push_back(result);
5536 }
5537
5538 rewriter.replaceOp(op, results);
5539 return success();
5540 }
5541 };
5542
5543 // Converts tf.Unpack to a series of XLA HLO Slice ops.
5544 // TODO(disc): To recover static special case's performance with folding and
5545 // canonicalization.
5546 class ConvertUnpackOpDynamic : public OpRewritePattern<TF::UnpackOp> {
5547 public:
5548 using OpRewritePattern::OpRewritePattern;
5549
5550 LogicalResult matchAndRewrite(TF::UnpackOp op,
5551 PatternRewriter &rewriter) const override {
5552 auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
5553 if (!value_type) return failure();
5554 // TODO(disc): Remove this constraint once folding and canonicalization are
5555 // implemented.
5556 if (value_type.hasStaticShape()) return failure();
5557
5558 int64_t value_rank = value_type.getRank();
5559 int64_t axis = op.axis();
5560 if (axis < 0) axis += value_rank;
5561 Location loc = op.getLoc();
5562
5563 auto shape_scalar_type = rewriter.getIntegerType(32);
5564 // Parameters for constructing each slice.
5565 SmallVector<Value, 4> begin_indices, end_indices, strides;
5566 begin_indices.reserve(value_rank);
5567 end_indices.reserve(value_rank);
5568 strides.reserve(value_rank);
5569 // final output shape
5570 SmallVector<Value, 4> shape_values;
5571 shape_values.reserve(value_rank - 1);
5572 // Slice shape before reshape; it should look like {?, 1, ?, ?} if axis = 1.
5573 SmallVector<int64_t, 4> slice_shape(value_rank, ShapedType::kDynamicSize);
5574 for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) {
5575 int64_t dim_size = value_type.getDimSize(dim_idx);
5576 if (dim_size == ShapedType::kDynamicSize) {
5577 Value dim_i = rewriter.create<arith::IndexCastOp>(
5578 loc, shape_scalar_type,
5579 rewriter.create<tensor::DimOp>(loc, op.getOperand(), dim_idx));
5580 end_indices.push_back(dim_i);
5581 if (dim_idx != axis) {
5582 shape_values.push_back(dim_i);
5583 }
5584 } else {
5585 Value dim_i = rewriter.create<arith::ConstantOp>(
5586 loc, shape_scalar_type,
5587 rewriter.getIntegerAttr(shape_scalar_type, dim_size));
5588 end_indices.push_back(dim_i);
5589 if (dim_idx != axis) {
5590 shape_values.push_back(dim_i);
5591 slice_shape[dim_idx] = dim_size;
5592 } else {
5593 slice_shape[dim_idx] = 1;
5594 }
5595 }
5596 begin_indices.push_back(
5597 rewriter.create<arith::ConstantIntOp>(loc, 0, 32));
5598 strides.push_back(rewriter.create<arith::ConstantIntOp>(loc, 1, 32));
5599 }
5600
5601 SmallVector<Value, 4> results;
5602 results.reserve(op.getNumResults());
5603 Type i32_ty = rewriter.getI32Type();
5604 for (int64_t i = 0; i < op.getNumResults(); ++i) {
5605 begin_indices[axis] = rewriter.create<arith::ConstantIntOp>(loc, i, 32);
5606 end_indices[axis] = rewriter.create<arith::ConstantIntOp>(loc, i + 1, 32);
5607 Value slice_op = rewriter.create<RealDynamicSliceOp>(
5608 loc, RankedTensorType::get(slice_shape, value_type.getElementType()),
5609 op.value(),
5610 rewriter.create<tensor::FromElementsOp>(
5611 loc,
5612 RankedTensorType::get(
5613 {static_cast<int64_t>(begin_indices.size())}, i32_ty),
5614 begin_indices),
5615 rewriter.create<tensor::FromElementsOp>(
5616 loc,
5617 RankedTensorType::get({static_cast<int64_t>(end_indices.size())},
5618 i32_ty),
5619 end_indices),
5620 rewriter.create<tensor::FromElementsOp>(
5621 loc,
5622 RankedTensorType::get({static_cast<int64_t>(strides.size())},
5623 i32_ty),
5624 strides));
5625 // Reshape to drop the axis dimension.
5626 Value new_shape = rewriter.create<tensor::FromElementsOp>(
5627 loc,
5628 RankedTensorType::get({static_cast<int64_t>(shape_values.size())},
5629 i32_ty),
5630 shape_values);
5631 Value reshape_op = rewriter.create<DynamicReshapeOp>(loc, op.getType(i),
5632 slice_op, new_shape);
5633 results.push_back(reshape_op);
5634 }
5635
5636 rewriter.replaceOp(op, results);
5637 return success();
5638 }
5639 };
5640
5641 // Converts the tf.SigmoidGradOp
5642 // TODO(disc): To recover static special case's performance with folding and
5643 // canonicalization.
5644 class ConvertSigmoidGradOpDynamic : public OpRewritePattern<TF::SigmoidGradOp> {
5645 public:
5646 using OpRewritePattern::OpRewritePattern;
5647
5648 LogicalResult matchAndRewrite(TF::SigmoidGradOp op,
5649 PatternRewriter &rewriter) const override {
5650 Location loc = op.getLoc();
5651 Value y = op.y();
5652 Value dy = op.dy();
5653 auto tp_y = y.getType().dyn_cast<RankedTensorType>();
5654 auto tp_dy = dy.getType().dyn_cast<RankedTensorType>();
5655 if (!tp_y || !tp_dy) return failure();
5656
5657 // TODO(disc): Remove this constraint once folding and canonicalization are
5658 // implemented.
5659 if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure();
5660
5661 Attribute attr;
5662 Type elem_tp = tp_y.getElementType();
5663 if (elem_tp.isSignlessInteger()) {
5664 attr = rewriter.getIntegerAttr(elem_tp, 1);
5665 } else {
5666 assert(elem_tp.isa<FloatType>());
5667 attr = rewriter.getFloatAttr(elem_tp, 1);
5668 }
5669 Value one = rewriter.create<mhlo::ConstantOp>(
5670 loc, DenseElementsAttr::get(RankedTensorType::get({}, elem_tp), attr));
5671
5672 auto v0 = rewriter.create<chlo::BroadcastMulOp>(
5673 loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y));
5674 auto v1 = rewriter.create<chlo::BroadcastSubOp>(
5675 loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y));
5676 auto result = rewriter.create<chlo::BroadcastMulOp>(
5677 loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1));
5678
5679 rewriter.replaceOp(op, result.getOperation()->getResults());
5680 return success();
5681 }
5682 };
5683
5684 // Converts TF unsorted segment reduction ops to XLA HLO scatter op.
5685 //
5686 // TF unsorted segment reduction op performs the following calculation:
5687 //
5688 // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is
5689 // [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's
5690 // shape, so we can have data's shape represented as [SI0, SI1, ..., SIm,
5691 // Dm+1, ..., Dn]. Then
5692 // output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] =
5693 // <ReductionOp> over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in]
5694 // where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN).
5695 //
5696 // The op will be translated to XLA HLO scatter with the following parameters:
5697 // * Update window dims is [segment_id_rank, data_rank).
5698 // * Inserted window dims is {0}.
5699 // * Scatter dims to operand dims mapping is {0}.
5700 // * Index vector dim is segment_id_rank.
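// A small hypothetical example (for illustration only): with data of shape
// [4, 8], segment_ids = [0, 2, 0, 1] of shape [4] and num_segments = 3, the
// output has shape [3, 8] and
//   output[0, :] = reduce(init, data[0, :], data[2, :])
//   output[1, :] = reduce(init, data[3, :])
//   output[2, :] = reduce(init, data[1, :])
// while the generated scatter uses update_window_dims = [1],
// inserted_window_dims = [0], scatter_dims_to_operand_dims = [0] and
// index_vector_dim = 1.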
5701 template <typename ConcreteClass, typename OpTy, typename ReductionOp>
5702 class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
5703 using OpRewritePattern<OpTy>::OpRewritePattern;
5704
5705 LogicalResult matchAndRewrite(OpTy op,
5706 PatternRewriter &rewriter) const override {
5707 auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
5708 if (!data_type) return failure();
5709 int64_t data_rank = data_type.getRank();
5710
5711 auto segment_ids_type =
5712 op.segment_ids().getType().template dyn_cast<RankedTensorType>();
5713 if (!segment_ids_type) return failure();
5714 int64_t segment_ids_rank = segment_ids_type.getRank();
5715
5716 DenseIntElementsAttr num_segments_attr;
5717 if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr)))
5718 return failure();
5719
5720 // The final shape for TF unsorted segment reduction op is [num_segments] +
5721 // data_shape[segment_ids_rank:].
5722 SmallVector<int64_t, 4> output_shape;
5723 output_shape.push_back((*num_segments_attr.begin()).getSExtValue());
5724 auto suffix = data_type.getShape().drop_front(segment_ids_rank);
5725 output_shape.append(suffix.begin(), suffix.end());
5726 auto output_type =
5727 RankedTensorType::get(output_shape, data_type.getElementType());
5728
5729 // Broadcast the initial value for reduction. This will become the
5730 // 'operand' parameter of the final scatter op.
5731 Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
5732 op.getLoc(), &rewriter);
5733 auto broadcasted_init = rewriter.create<mhlo::BroadcastOp>(
5734 op.getLoc(), output_type, init,
5735 GetI64ElementsAttr(output_shape, &rewriter));
5736
5737 // Parameters for the generated scatter op.
5738 SmallVector<int64_t, 1> inserted_window_dims(1, 0);
5739 SmallVector<int64_t, 1> scatter_dims_to_operand_dims(1, 0);
5740 int64_t index_vector_dim = segment_ids_rank;
5741
5742 // Put all parameters in a StructAttr.
5743 auto dims_attr = ScatterDimensionNumbersAttr::get(
5744 rewriter.getContext(),
5745 llvm::to_vector<4>(llvm::seq<int64_t>(segment_ids_rank, data_rank)),
5746 inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim);
5747
5748 auto scatter = rewriter.create<ScatterOp>(
5749 op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)),
5750 op.segment_ids(), op.data(), dims_attr);
5751 BuildReduceBody<ReductionOp>(data_type.getElementType(),
5752 &scatter.update_computation(), &rewriter);
5753
5754 rewriter.replaceOp(op, scatter.getResult(0));
5755 return success();
5756 }
5757 };
5758
5759 class ConvertUnsortedSegmentMaxOp
5760 : public GenericConvertUnsortedSegmentReductionOp<
5761 ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> {
5762 public:
5763 using GenericConvertUnsortedSegmentReductionOp::
5764 GenericConvertUnsortedSegmentReductionOp;
5765
5766 static Value GetInitialValue(Type reduce_element_type, Location loc,
5767 PatternRewriter *rewriter) {
5768 return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest,
5769 rewriter);
5770 }
5771 };
5772
5773 class ConvertUnsortedSegmentMinOp
5774 : public GenericConvertUnsortedSegmentReductionOp<
5775 ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> {
5776 public:
5777 using GenericConvertUnsortedSegmentReductionOp::
5778 GenericConvertUnsortedSegmentReductionOp;
5779
5780 static Value GetInitialValue(Type reduce_element_type, Location loc,
5781 PatternRewriter *rewriter) {
5782 return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax,
5783 rewriter);
5784 }
5785 };
5786
5787 class ConvertUnsortedSegmentProdOp
5788 : public GenericConvertUnsortedSegmentReductionOp<
5789 ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> {
5790 public:
5791 using GenericConvertUnsortedSegmentReductionOp::
5792 GenericConvertUnsortedSegmentReductionOp;
5793
5794 static Value GetInitialValue(Type reduce_element_type, Location loc,
5795 PatternRewriter *rewriter) {
5796 return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter);
5797 }
5798 };
5799
5800 class ConvertUnsortedSegmentSumOp
5801 : public GenericConvertUnsortedSegmentReductionOp<
5802 ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> {
5803 public:
5804 using GenericConvertUnsortedSegmentReductionOp::
5805 GenericConvertUnsortedSegmentReductionOp;
5806
5807 static Value GetInitialValue(Type reduce_element_type, Location loc,
5808 PatternRewriter *rewriter) {
5809 return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter);
5810 }
5811 };
5812
5813 // Converts tf.RandomShuffle op into a series of XLA HLO ops.
5814 //
5815 // tf.RandomShuffle shuffles tensors along the first dimension. If the input
5816 // tensor's rank is 1, then it is translated into HLO sort op(s) according to
5817 // indices randomly generated via HLO rng_uniform ops. Otherwise, it is
5818 // translated into an HLO while op to first emulate shuffling indices using
5819 // HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather
5820 // with the shuffled indices.
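// For the rank-1 case, a rough sketch of one sorting round (illustrative
// only; the exact rng/sort printing may differ): generate a tensor of random
// 32-bit keys with an HLO rng_uniform over [0, 2^32 - 1), sort (keys, values)
// together with an mhlo.sort whose comparator orders the keys with LT, and
// keep the permuted values. With exponent = 3 and, say, 1024 elements, a
// single round suffices because ceil(3 * log(1024) / log(2^32 - 1)) = 1.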
5821 class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
5822 public:
5823 using OpRewritePattern::OpRewritePattern;
5824
5825 LogicalResult matchAndRewrite(TF::RandomShuffleOp op,
5826 PatternRewriter &rewriter) const override {
5827 auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
5828 if (!input_type) return failure();
5829
5830 int64_t input_rank = input_type.getRank();
5831 int64_t first_dim_size = input_type.getDimSize(0);
5832 if (ShapedType::isDynamic(first_dim_size)) return failure();
5833
5834 // We are shuffling along the first dimension. If its size is <= 1, then
5835 // shuffling is a no-op.
5836 if (first_dim_size <= 1) {
5837 rewriter.replaceOp(op, op.value());
5838 return success();
5839 }
5840
5841 // For vectors, shuffle values by sorting instead of the obvious
5842 // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct,
5843 // but not easily parallelizable. For a sufficiently parallel architecture,
5844 // it is faster to sort many times than to Fisher-Yates shuffle once.
5845 if (input_rank == 1) {
5846 // Shuffle values by assigning each value a random key and sorting the
5847 // keys. Keys can collide, causing detectable patterns in the shuffled
5848 // output. Collisions translate into more ascending sub-sequences in the
5849 // shuffled output than would be expected by chance. To avoid collisions,
5850 // the number of possible key values must be sufficiently large.
5851
5852 // How are more than 2^32 keys created? In each loop iteration, the
5853 // algorithm sorts by random keys. Conceptually, the earlier iterations
5854 // are sorting on the lower-order bits of larger keys that are never
5855 // actually assembled.
5856
5857 // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
5858 // the number of possible keys and n is the number of values. If d = n^2,
5859 // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
5860 // as n goes to infinity is zero.
5861
5862 // This implementation ensures that the key-space is greater than or equal
5863 // to the cube of the number of values. The risk of collisions can be
5864 // further reduced by increasing Exponent at the expense of
5865 // performance.
5866
5867 // For Exponent = 2, the expected number of collisions per shuffle is
5868 // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
5869 // about 1/2.
5870
5871 // For Exponent = 3, the expected number of collisions per shuffle is
5872 // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
5873 // about 1/3255.
5874
5875 // For Exponent = 4, the expected number of collisions per shuffle is
5876 // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
5877 // about 1/132622.
5878 constexpr int exponent = 3;
5879 int64_t num_elements = input_type.getNumElements();
5880 uint32_t u32_max = std::numeric_limits<uint32_t>::max();
5881 int rounds =
5882 std::ceil(exponent * std::log(num_elements) / std::log(u32_max));
5883
5884 Value current = op.value();
5885 for (int i = 0; i < rounds; ++i) {
5886 auto keys =
5887 CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
5888 /*upper_limit=*/u32_max, &rewriter);
5889 auto sorted = createSortOp(
5890 &rewriter, op.getLoc(), {keys, current},
5891 {rewriter.getIntegerType(32), input_type.getElementType()},
5892 /*dimension=*/-1, /*is_stable=*/false,
5893 /*direction=*/ComparisonDirection::LT);
5894 current = sorted.getResult(1);
5895 }
5896 rewriter.replaceOp(op, current);
5897 return success();
5898 }
5899
5900 // The Fisher-Yates algorithm.
5901
5902 // Generate range(n) as the initial value for the indices to be swapped.
5903 auto indices_type =
5904 RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
5905 Value indices = rewriter.create<mhlo::IotaOp>(
5906 op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));
5907
5908 // Generate random numbers to be used as swaps for the indices.
5909 Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0,
5910 first_dim_size, &rewriter);
5911
5912 // While loop body to perform index swaps.
5913 auto swap_body_fn = [&](Location loc, Value i, ArrayRef<Value> old_values,
5914 SmallVectorImpl<Value> *new_values,
5915 OpBuilder *builder) {
5916 Value swaps = old_values[0];
5917 Value indices = old_values[1];
5918
5919 auto scalar_i32_type =
5920 RankedTensorType::get({}, builder->getIntegerType(32));
5921 auto scalar_i64_type =
5922 RankedTensorType::get({}, builder->getIntegerType(64));
5923
5924 auto scalar_one =
5925 DenseIntElementsAttr::get(scalar_i64_type, ArrayRef<int64_t>(1));
5926
5927 // We need to swap the indices[i] with indices[swaps[i]]. First get
5928 // these index values.
5929 Value source_index =
5930 builder->create<mhlo::DynamicSliceOp>(loc, indices, i, scalar_one);
5931 Value swap_index = builder->create<mhlo::ReshapeOp>(
5932 loc, scalar_i32_type,
5933 builder->create<mhlo::DynamicSliceOp>(loc, swaps, i, scalar_one));
5934 Value target_index = builder->create<mhlo::DynamicSliceOp>(
5935 loc, indices, swap_index, scalar_one);
5936
5937 // Then perform the swap.
5938 // indices[i] <- indices[swaps[i]]
5939 indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5940 loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
5941 // indices[swaps[i]] <- indices[i]
5942 indices = builder->create<mhlo::DynamicUpdateSliceOp>(
5943 loc, indices.getType(), indices, source_index,
5944 llvm::makeArrayRef(swap_index));
5945
5946 // Update new values.
5947 new_values->assign({swaps, indices});
5948 };
5949
5950 // Create a while op to swap indices.
5951 SmallVector<Value, 2> while_output;
5952 CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices},
5953 &while_output, &rewriter);
5954 Value swapped_indices = while_output[1];
5955
5956 // Gather the data using the swapped indices as the shuffled order.
5957 ArrayRef<int64_t> input_shape = input_type.getShape();
5958 SmallVector<int64_t, 4> slice_sizes(input_shape.begin(), input_shape.end());
5959 slice_sizes[0] = 1;
5960 auto dims_attr = GatherDimensionNumbersAttr::get(
5961 rewriter.getContext(),
5962 /*offset_dims=*/llvm::to_vector<4>(llvm::seq<int64_t>(1, input_rank)),
5963 /*collapsed_slice_dims=*/{0},
5964 /*start_index_map=*/{0},
5965 /*index_vector_dim=*/1);
5966 rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
5967 op, op.getType(), op.value(), swapped_indices, dims_attr,
5968 GetI64ElementsAttr(slice_sizes, &rewriter));
5969
5970 return success();
5971 }
5972 };
5973
5974 // Converts an XlaSharding op to a XLA HLO shard op with sharding attributes.
5975 class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
5976 public:
5977 using OpRewritePattern::OpRewritePattern;
5978
5979 LogicalResult matchAndRewrite(TF::XlaShardingOp op,
5980 PatternRewriter &rewriter) const override {
5981 // TODO(b/148313088): define sharding attribute struct in MLIR instead of
5982 // using a string.
5983 if (!op._XlaSharding().has_value()) return failure();
5984
5985 NamedAttribute call_target_name = rewriter.getNamedAttr(
5986 "call_target_name", rewriter.getStringAttr("Sharding"));
5987
5988 auto custom_call = rewriter.create<mhlo::CustomCallOp>(
5989 op.getLoc(), op.getType(), op.input(),
5990 ArrayRef<NamedAttribute>{call_target_name});
5991 custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
5992 rewriter.replaceOp(op, custom_call.getResult(0));
5993
5994 return success();
5995 }
5996 };
5997
5998 // Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO.
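// A hypothetical lowering sketch (names and shapes are illustrative only):
// with x : tensor<5x3xf32>, i : tensor<2xi32> and v : tensor<2x3xf32>, the
// rewrite unpacks i into two scalar indices, splits v into two
// tensor<1x3xf32> rows, and then emits a chain of updates:
//
//   %x1 = "mhlo.dynamic_update_slice"(%x,  %v0, %i0, %zero) : ... -> tensor<5x3xf32>
//   %x2 = "mhlo.dynamic_update_slice"(%x1, %v1, %i1, %zero) : ... -> tensor<5x3xf32>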
5999 class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
6000 public:
6001 using OpRewritePattern::OpRewritePattern;
6002
6003 LogicalResult matchAndRewrite(TF::InplaceUpdateOp op,
6004 PatternRewriter &rewriter) const override {
6005 auto input = op.x();
6006 auto indices = op.i();
6007 auto updates = op.v();
6008
6009 // Slice each row of `i` and `v` to perform a separate dynamic-update-slice
6010 // on the contents of `x`.
6011 auto input_type = input.getType().cast<ShapedType>();
6012 auto updates_type = updates.getType().cast<ShapedType>();
6013 auto indices_type = indices.getType().cast<ShapedType>();
6014 if (!input_type.hasRank()) return failure();
6015 if (!updates_type.hasRank() || updates_type.isDynamicDim(0))
6016 return failure();
6017 if (!indices_type.hasStaticShape()) return failure();
6018
6019 if (indices_type.getRank() != 1) return failure();
6020
6021 SmallVector<Type, 4> unpacked_indices_type(
6022 indices_type.getDimSize(0),
6023 RankedTensorType::get({}, indices_type.getElementType()));
6024 // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are
6025 // required to have matching types. This rewrite rule creates
6026 // DynamicUpdateSlice ops where the first "start index" is always i32 and
6027 // subsequent ones are constructed based on zero_attr. Thus the type
6028 // for zero_attr needs to be i32 as well.
6029 auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0);
6030 auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6031 op.getLoc(), unpacked_indices_type, indices, zero_attr);
6032
6033 SmallVector<int64_t, 4> split_updates_shape;
6034 split_updates_shape.append(updates_type.getShape().begin(),
6035 updates_type.getShape().end());
6036 split_updates_shape.front() = 1;
6037 SmallVector<Type, 4> split_updates_type;
6038 split_updates_type.resize(
6039 updates_type.getShape().front(),
6040 RankedTensorType::get(split_updates_shape,
6041 updates_type.getElementType()));
6042
6043 auto cst =
6044 rewriter.create<mhlo::ConstantOp>(op.getLoc(), zero_attr).getResult();
6045 auto split_updates = rewriter.create<TF::SplitOp>(
6046 op.getLoc(), split_updates_type, cst, updates);
6047
6048 SmallVector<Value, 6> input_indices;
6049 input_indices.resize(input_type.getRank(), cst);
6050
6051 for (auto pair :
6052 llvm::zip(unpacked_indices.output(), split_updates.output())) {
6053 input_indices.front() = std::get<0>(pair);
6054 input = rewriter.create<mhlo::DynamicUpdateSliceOp>(
6055 op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
6056 }
6057
6058 rewriter.replaceOp(op, input);
6059 return success();
6060 }
6061 };
6062
6063 // Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO.
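// A hypothetical sketch (shapes for illustration only): with
// indices : tensor<2xi32>, the indices vector is unpacked into scalar tensors
// %i0 and %i1 via tf.Unpack, and the op becomes
//
//   %res = "mhlo.dynamic_update_slice"(%input, %update, %i0, %i1)
//       : (tensor<4x4xf32>, tensor<2x2xf32>, tensor<i32>, tensor<i32>)
//       -> tensor<4x4xf32>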
6064 class ConvertXlaDynamicUpdateSliceOp
6065 : public OpRewritePattern<TF::XlaDynamicUpdateSliceOp> {
6066 public:
6067 using OpRewritePattern::OpRewritePattern;
6068
6069 LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op,
6070 PatternRewriter &rewriter) const override {
6071 auto indices_type = op.indices().getType().dyn_cast<RankedTensorType>();
6072 if (!indices_type || !indices_type.hasStaticShape() ||
6073 indices_type.getShape().size() != 1)
6074 return failure();
6075
6076 SmallVector<Type, 4> unpacked_indices_type(
6077 indices_type.getDimSize(0),
6078 RankedTensorType::get({}, indices_type.getElementType()));
6079 auto unpacked_indices = rewriter.create<TF::UnpackOp>(
6080 op.getLoc(), unpacked_indices_type, op.indices(),
6081 IntegerAttr::get(rewriter.getIntegerType(64), 0));
6082 rewriter.replaceOpWithNewOp<mhlo::DynamicUpdateSliceOp>(
6083 op, op.getType(), op.input(), op.update(), unpacked_indices.output());
6084 return success();
6085 }
6086 };
6087
6088 // Converts a TF XlaReduceScatter op to ReduceScatter HLO.
6089 class ConvertXlaReduceScatterOp
6090 : public OpRewritePattern<TF::XlaReduceScatterOp> {
6091 using OpRewritePattern::OpRewritePattern;
6092
6093 LogicalResult matchAndRewrite(TF::XlaReduceScatterOp op,
6094 PatternRewriter &rewriter) const override {
6095 DenseIntElementsAttr group_assignment;
6096 if (!matchPattern(op.group_assignment(), m_Constant(&group_assignment)))
6097 return failure();
6098 auto replica_groups =
6099 hlo::convertElementsAttr(group_assignment, rewriter.getIntegerType(64))
6100 .cast<DenseIntElementsAttr>();
6101 if (replica_groups.getType().getRank() != 2) return failure();
6102
6103 APInt scatter_dimension;
6104 if (!matchPattern(op.scatter_dimension(),
6105 m_ConstantInt(&scatter_dimension)))
6106 return failure();
6107
6108 Location loc = op.getLoc();
6109 Type element_type = getElementTypeOrSelf(op.input().getType());
6110
6111 auto reduce_scatter = rewriter.create<ReduceScatterOp>(
6112 loc, op.getType(), op.input(),
6113 rewriter.getIntegerAttr(rewriter.getIntegerType(64),
6114 scatter_dimension.getSExtValue()),
6115 replica_groups, ChannelHandleAttr());
6116 StringRef reduce_op = op.reduce_op();
6117 if (reduce_op == "Add") {
6118 BuildReduceBody<AddOp>(element_type, &reduce_scatter.computation(),
6119 &rewriter);
6120 } else if (reduce_op == "Mul") {
6121 BuildReduceBody<MulOp>(element_type, &reduce_scatter.computation(),
6122 &rewriter);
6123 } else if (reduce_op == "Min") {
6124 BuildReduceBody<MinOp>(element_type, &reduce_scatter.computation(),
6125 &rewriter);
6126 } else if (reduce_op == "Max") {
6127 BuildReduceBody<MaxOp>(element_type, &reduce_scatter.computation(),
6128 &rewriter);
6129 } else {
6130 // For mean, add replicas in the same group. Then divide the sum by the
6131 // number of replicas in each group below.
6132 assert(reduce_op == "Mean");
6133 BuildReduceBody<AddOp>(element_type, &reduce_scatter.computation(),
6134 &rewriter);
6135 }
6136 Value result = reduce_scatter.getResult();
6137
6138 // For mean, divide the merge result by group size.
6139 if (reduce_op == "Mean") {
6140 int64_t replica_group_size = replica_groups.getType().getDimSize(1);
6141 if (replica_group_size == 0) return failure();
6142 auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size,
6143 &rewriter);
6144 auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
6145 result = rewriter.create<chlo::BroadcastDivOp>(
6146 loc, result, divisor.getResult(), broadcast_dims);
6147 }
6148
6149 rewriter.replaceOp(op, {result});
6150 return success();
6151 }
6152 };
6153
6154 // Converts tf.XlaReduceWindow to mhlo.ReduceWindow
6155 class ConvertXlaReduceWindowOp
6156 : public OpRewritePattern<TF::XlaReduceWindowOp> {
6157 using OpRewritePattern::OpRewritePattern;
6158
6159 LogicalResult matchAndRewrite(TF::XlaReduceWindowOp op,
6160 PatternRewriter &rewriter) const override {
6161 DenseElementsAttr window_dimensions, window_strides, base_dilations,
6162 window_dilations, padding;
6163 if (!(matchPattern(op.window_dimensions(),
6164 m_Constant(&window_dimensions)) &&
6165 matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
6166 matchPattern(op.base_dilations(), m_Constant(&base_dilations)) &&
6167 matchPattern(op.window_dilations(), m_Constant(&window_dilations)) &&
6168 matchPattern(op.padding(), m_Constant(&padding))))
6169 return failure();
6170
6171 Location loc = op.getLoc();
6172
6173 SmallVector<Type> result_types{op.getResult().getType()};
6174 // Create the mhlo.reduce_window op.
6175 auto reduce_window_op = rewriter.create<mhlo::ReduceWindowOp>(
6176 loc, result_types, op.input(), op.init_value(),
6177 hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64))
6178 .cast<DenseIntElementsAttr>(),
6179 hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64))
6180 .cast<DenseIntElementsAttr>(),
6181 hlo::convertElementsAttr(base_dilations, rewriter.getIntegerType(64))
6182 .cast<DenseIntElementsAttr>(),
6183 hlo::convertElementsAttr(window_dilations, rewriter.getIntegerType(64))
6184 .cast<DenseIntElementsAttr>(),
6185 hlo::convertElementsAttr(padding, rewriter.getIntegerType(64))
6186 .cast<DenseIntElementsAttr>());
6187 // Insert a call to the reducer in the region of the mhlo op.
6188 mlir::SymbolRefAttr func = op.computation();
6189 auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
6190 op->getParentOfType<mlir::ModuleOp>(), func));
6191 auto func_ty = func_op.getFunctionType();
6192 BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_window_op.body());
6193
6194 rewriter.replaceOp(op, reduce_window_op.getResults());
6195
6196 return success();
6197 }
6198 };
6199
6200 // Converts ClipByValue to XLA's clamp operation. Includes the broadcasting
6201 // semantics for static and dynamic cases.
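// A hypothetical sketch (shapes for illustration only): with
// t : tensor<2x3xf32> and scalar clip values, min and max are first broadcast
// to the shape of t via tf.BroadcastTo, and the clip becomes
//
//   %res = "mhlo.clamp"(%min_bcast, %t, %max_bcast)
//       : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>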
6202 class ConvertClipByValueOp : public OpRewritePattern<TF::ClipByValueOp> {
6203 public:
6204 using OpRewritePattern::OpRewritePattern;
6205
6206 LogicalResult matchAndRewrite(TF::ClipByValueOp op,
6207 PatternRewriter &rewriter) const override {
6208 Value input = op.t();
6209 Value min = op.clip_value_min();
6210 Value max = op.clip_value_max();
6211
6212 auto input_ty = input.getType().cast<ShapedType>();
6213 auto min_ty = min.getType().cast<ShapedType>();
6214 auto max_ty = max.getType().cast<ShapedType>();
6215
6216 if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) {
6217 return failure();
6218 }
6219
6220 auto shape = rewriter.create<TF::ShapeOp>(
6221 op.getLoc(),
6222 RankedTensorType::get({input_ty.getRank()}, rewriter.getI32Type()),
6223 input);
6224
6225 if (min_ty != input_ty) {
6226 min =
6227 rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, min, shape);
6228 }
6229
6230 if (max_ty != input_ty) {
6231 max =
6232 rewriter.create<TF::BroadcastToOp>(op.getLoc(), input_ty, max, shape);
6233 }
6234
6235 rewriter.replaceOpWithNewOp<mhlo::ClampOp>(op, input_ty, min, input, max);
6236 return success();
6237 }
6238 };
6239
6240 // Converts ConstOp to XLA's constant operation and introduces a tensor cast if
6241 // needed.
6242 class ConvertConstOp : public OpRewritePattern<TF::ConstOp> {
6243 public:
6244 using OpRewritePattern::OpRewritePattern;
6245
6246 LogicalResult matchAndRewrite(TF::ConstOp op,
6247 PatternRewriter &rewriter) const override {
6248 // Convert only for valid HLO tensors.
6249 auto ty = op.getType().dyn_cast<TensorType>();
6250 if (!ty || !ty.getElementType().isa<FloatType, IntegerType, ComplexType>())
6251 return failure();
6252
6253 Location loc = op.getLoc();
6254 Value result = rewriter.create<mhlo::ConstantOp>(loc, op.value());
6255 if (result.getType() != op.getType())
6256 result = rewriter.create<tensor::CastOp>(loc, op.getType(), result);
6257 rewriter.replaceOp(op, result);
6258 return success();
6259 }
6260 };
6261
6262 // Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
6263 // setting appropriate window dimensions, with the given aggregation op as the
6264 // reduction function. The input tensor needs to have a static shape, and 'axis'
6265 // must be const. The TableGen pattern is not used for this rewrite because it
6266 // involves regions.
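// A hypothetical worked example (illustration only): a cumulative sum over a
// tensor<4xf32> input [a, b, c, d] along axis 0 uses window_dimensions = [4],
// window_strides = [1] and padding = [[3, 0]], producing
// [a, a+b, a+b+c, a+b+c+d]. With exclusive = true the result is additionally
// padded with the init value on the left and trimmed on the right, yielding
// [0, a, a+b, a+b+c]; with reverse = true both the input and the output are
// reversed along the axis.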
6267 template <typename OpT, typename AggregationOp>
6268 class ConvertCumOp : public OpRewritePattern<OpT> {
6269 using OpRewritePattern<OpT>::OpRewritePattern;
6270
6271 LogicalResult matchAndRewrite(OpT op,
6272 PatternRewriter &rewriter) const override {
6273 auto input = op.x();
6274 auto input_type = input.getType().template dyn_cast<ShapedType>();
6275 if (!input_type || !input_type.hasStaticShape()) {
6276 return failure();
6277 }
6278
6279 ArrayRef<int64_t> input_shape = input_type.getShape();
6280 int64_t rank = input_shape.size();
6281
6282 // We can only match when the axis is a constant scalar.
6283 DenseIntElementsAttr axis_attr;
6284 if (!matchPattern(op.axis(), m_Constant(&axis_attr))) {
6285 return failure();
6286 }
6287
6288 // Get the dimension to apply the reduction on, and offset properly if it is
6289 // negative.
6290 int64_t axis = (*axis_attr.begin()).getSExtValue();
6291 if (axis < 0) {
6292 axis += rank;
6293 }
6294
6295 // If we're supposed to sum things up in the reverse direction, we reverse
6296 // the input and then later reverse the output.
6297 if (op.reverse()) {
6298 llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6299 input = rewriter.create<ReverseOp>(
6300 op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter));
6301 }
6302
6303 // Convert if we need to enlarge the element type's bitwidth to avoid
6304 // precision loss.
6305 Type input_element_type = input_type.getElementType();
6306
6307 // TODO(hinsu): Handle complex element types.
6308 if (!input_element_type.isIntOrFloat()) return failure();
6309
6310 Type sum_element_type = GetSumAccumulationType(input_element_type);
6311 input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
6312
6313 SmallVector<int64_t, 4> window_dims(rank, 1);
6314 SmallVector<int64_t, 4> window_strides(rank, 1);
6315 window_dims[axis] = input_shape[axis];
6316
6317 SmallVector<int64_t, 8> paddings(rank * 2, 0);
6318 paddings[axis * 2] =
6319 std::max(input_shape[axis] - 1, static_cast<int64_t>(0));
6320 auto paddings_attr = DenseIntElementsAttr::get(
6321 RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
6322 paddings);
6323
6324 int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
6325 Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
6326 &rewriter);
6327
6328 auto reduce = rewriter.create<ReduceWindowOp>(
6329 op.getLoc(), input.getType(), input, init,
6330 GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)),
6331 GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
6332 /*base_dilations=*/DenseIntElementsAttr(),
6333 /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
6334 BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
6335 Value result = reduce.getResult(0);
6336
6337 if (op.exclusive()) {
6338 // In "exclusive" operation, the output will start with the "init" (0)
6339 // values. There is no way to express that as a ReduceWindowOp, so run the
6340 // normal operation, and then use a PadOp to add the 0 "column" on the
6341 // left and cut away the last column on the right.
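      // For example, along an axis of size 4 the inclusive result
      // [a, a+b, a+b+c, a+b+c+d] becomes [init, a, a+b, a+b+c]: one "init"
      // column is added on the left and the last column is dropped via the
      // negative high padding.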
6342 llvm::SmallVector<int64_t, 4> low_padding(rank, 0);
6343 llvm::SmallVector<int64_t, 4> high_padding(rank, 0);
6344 llvm::SmallVector<int64_t, 4> interior_padding(rank, 0);
6345 low_padding[axis] = 1;
6346 high_padding[axis] = -1;
6347 result = rewriter.create<PadOp>(
6348 op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter),
6349 GetI64ElementsAttr(high_padding, &rewriter),
6350 GetI64ElementsAttr(interior_padding, &rewriter));
6351 }
6352
6353 // Convert back if we enlarged the element type's bitwidth.
6354 result =
6355 rewriter.create<ConvertOp>(op.getLoc(), result, input_element_type);
6356
6357 if (op.reverse()) {
6358 llvm::SmallVector<int64_t, 4> dims_to_reverse({axis});
6359 result = rewriter.create<ReverseOp>(
6360 op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter));
6361 }
6362
6363 rewriter.replaceOp(op, result);
6364 return success();
6365 }
6366 };
6367
6368 using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
6369 using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
6370
// Converts the TensorFlow ShapeOp to a sequence of Shape and Arith dialect
// lowerings. This involves extracting the shape type, extracting and
// converting each dimension to a known integer type, and repacking into a
// final tensor.
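//
// Schematically (types illustrative only):
//   %shape = shape.shape_of %input : tensor<?x?xf32> -> tensor<2xindex>
//   %result = arith.index_cast %shape : tensor<2xindex> to tensor<2xi32>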
6375 class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
6376 public:
6377 using OpRewritePattern::OpRewritePattern;
6378
  LogicalResult matchAndRewrite(TF::ShapeOp op,
                                PatternRewriter &rewriter) const override {
6381 Value input = op.input();
6382
6383 auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
6384 if (!result_ty) {
6385 return failure();
6386 }
6387
6388 auto index_tensor =
6389 RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
6390 auto shape_op =
6391 rewriter.create<shape::ShapeOfOp>(op.getLoc(), index_tensor, input);
6392 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, result_ty, shape_op);
6393 return success();
6394 }
6395 };
6396
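// Converts a TF ExpandDims op with a dynamically shaped result to
// mhlo.dynamic_reshape. The target extents are assembled from the input's
// runtime dimensions plus a constant 1 at the (constant) insertion index.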
6397 class ConvertDynamicExpandDimsOp : public OpRewritePattern<TF::ExpandDimsOp> {
6398 public:
6399 using OpRewritePattern::OpRewritePattern;
6400
  LogicalResult matchAndRewrite(TF::ExpandDimsOp op,
                                PatternRewriter &rewriter) const override {
6403 auto input = op.input();
6404 auto input_ty = input.getType().cast<ShapedType>();
6405 auto result_ty = op.getType().cast<ShapedType>();
6406 if (!result_ty.hasRank() || !input_ty.hasRank() ||
6407 result_ty.hasStaticShape()) {
6408 return failure();
6409 }
6410
6411 DenseIntElementsAttr expand_dims_attr;
6412 if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) {
6413 return failure();
6414 }
6415
6416 auto shape = rewriter.create<shape::ShapeOfOp>(
6417 op.getLoc(),
6418 RankedTensorType::get({input_ty.getRank()}, rewriter.getIndexType()),
6419 input);
6420 auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getValues<APInt>());
6421
6422 llvm::SmallVector<Value, 4> dims;
6423 dims.resize(result_ty.getRank());
6424
6425 auto inserted_dim = expand_dims[0].getSExtValue();
6426
6427 // Handle the negative value use case.
6428 if (inserted_dim < 0) {
6429 inserted_dim += result_ty.getRank();
      // The dimension is still out of range after wrapping; bail out.
6431 if (inserted_dim < 0) {
6432 return failure();
6433 }
6434 }
6435
6436 dims[inserted_dim] =
6437 rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
6438
6439 for (int i = 0; i < dims.size() - 1; i++) {
6440 // Add the extracted dim.
6441 Value index = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), i);
6442 Value dim = rewriter.create<tensor::ExtractOp>(op.getLoc(), shape, index);
6443 dims[i >= inserted_dim ? i + 1 : i] = dim;
6444 }
6445
6446 auto from_extents =
6447 rewriter.create<tensor::FromElementsOp>(op.getLoc(), dims);
6448 rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, input,
6449 from_extents);
6450 return success();
6451 }
6452 };
6453
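// Converts a TF Squeeze op with a dynamically shaped result to
// chlo.dynamic_reshape, using the runtime extents of the dimensions that are
// not squeezed. Requires an explicit, non-empty squeeze_dims attribute.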
6454 class ConvertDynamicSqueezeOp : public OpRewritePattern<TF::SqueezeOp> {
6455 public:
6456 using OpRewritePattern::OpRewritePattern;
6457
  LogicalResult matchAndRewrite(TF::SqueezeOp op,
                                PatternRewriter &rewriter) const override {
6460 auto input = op.input();
6461 auto input_ty = input.getType().cast<ShapedType>();
6462 auto result_ty = op.getType().cast<ShapedType>();
6463 if (!result_ty.hasRank() || !input_ty.hasRank() ||
6464 result_ty.hasStaticShape()) {
6465 return failure();
6466 }
6467
6468 // The fully dynamic case is unsupported.
6469 if (op.squeeze_dims().empty()) {
6470 return failure();
6471 }
6472
6473 SmallVector<int64_t> squeeze_dims;
6474 int64_t input_rank = input_ty.getRank();
6475 for (const auto &squeeze_dim_apint :
6476 op.squeeze_dims().getAsValueRange<IntegerAttr>()) {
6477 int64_t squeeze_dim = squeeze_dim_apint.getSExtValue();
6478 // Handle negative inputs.
6479 if (squeeze_dim < 0) squeeze_dim += input_rank;
6480 assert(squeeze_dim >= 0 && squeeze_dim < input_rank &&
6481 "squeeze dim out of bounds");
6482
6483 squeeze_dims.push_back(squeeze_dim);
6484 }
6485
6486 // Collect the unsqueezed dimensions.
6487 llvm::SmallVector<Value> dims;
6488 for (int64_t i = 0; i != input_rank; ++i) {
6489 if (llvm::is_contained(squeeze_dims, i)) continue;
6490 dims.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), input, i));
6491 }
6492
6493 auto from_extents =
6494 rewriter.create<tensor::FromElementsOp>(op.getLoc(), dims);
6495 // chlo::DynamicReshapeOp checks if the reshape is legal and will fail if
6496 // any non-1 dimension is squeezed.
6497 rewriter.replaceOpWithNewOp<chlo::DynamicReshapeOp>(op, result_ty, input,
6498 from_extents);
6499 return success();
6500 }
6501 };
6502
6503 // Converts a TF QR op to HLO.
6504 class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
6505 public:
6506 using OpRewritePattern::OpRewritePattern;
6507
  LogicalResult matchAndRewrite(TF::QrOp op,
                                PatternRewriter &rewriter) const override {
    // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van
    // Loan.
    // def qr_blocked(a, block_size):
    //   m = a.shape[0]
    //   n = a.shape[1]
    //   q = np.eye(m)
    //   for i in xrange(0, min(m, n), block_size):
    //     k = min(block_size, min(m, n) - i)
    //     (a, vs, taus) = qr(a[i:, i:i+k])
    //     y = vs
    //     w = ComputeWYRepresentation(vs, taus, m-i, k)
    //     a[i:, i+k:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
    //     q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
    //   return (q, a)
6523 auto type = op.input().getType().dyn_cast<RankedTensorType>();
6524 if (!type || !type.hasStaticShape()) return failure();
6525 // The block size is chosen to match old bridge lowering.
6526 constexpr int64_t kBlockSize = 128;
6527 Value a = op.input();
6528 int64_t m = type.getDimSize(type.getRank() - 2);
6529 int64_t n = type.getDimSize(type.getRank() - 1);
6530 int64_t p = std::min(m, n);
6531 auto batch_dims = type.getShape().drop_back(2);
6532 auto iota_type = RankedTensorType::get({m, m}, rewriter.getIntegerType(32));
6533 auto iota0 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6534 rewriter.getI64IntegerAttr(0));
6535 auto iota1 = rewriter.create<IotaOp>(op.getLoc(), iota_type,
6536 rewriter.getI64IntegerAttr(1));
6537 Value compare = rewriter.create<CompareOp>(op.getLoc(), iota0, iota1,
6538 ComparisonDirection::EQ);
6539 Value identity_matrix =
6540 rewriter.create<ConvertOp>(op.getLoc(), compare, type.getElementType());
6541 auto q_shape = llvm::to_vector<4>(type.getShape());
6542 q_shape.back() = m;
6543 Value q =
6544 rewriter.create<BroadcastOp>(op.getLoc(), identity_matrix,
6545 GetI64ElementsAttr(batch_dims, &rewriter));
6546 auto precision_config = rewriter.getArrayAttr(
6547 {PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST),
6548 PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST)});
6549 for (int64_t i = 0; i < p; i += kBlockSize) {
6550 int64_t k = std::min(kBlockSize, p - i);
6551 auto a_block =
6552 SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter);
6553 Value r_block;
6554 Value taus;
6555 Value vs;
6556 QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter);
6557 a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter);
6558
6559 // Compute the I-WY block representation of a product of Householder
6560 // matrices.
6561 Value w =
6562 ComputeWYRepresentation(op.getLoc(), type.getElementType(),
6563 batch_dims, vs, taus, m - i, k, &rewriter);
6564 auto y = vs;
6565
6566 // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
6567 Value a_panel =
6568 SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter);
6569 auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false,
6570 batch_dims.size(), precision_config, &rewriter);
6571 a_update = BatchDot(op.getLoc(), y, false, a_update, false,
6572 batch_dims.size(), precision_config, &rewriter);
6573 a_panel = rewriter.create<AddOp>(op.getLoc(), a_panel, a_update);
6574 a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k},
6575 &rewriter);
6576
      // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)
6578 Value q_panel =
6579 SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter);
6580 Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false,
6581 batch_dims.size(), precision_config, &rewriter);
6582 q_update = BatchDot(op.getLoc(), q_update, false, y, true,
6583 batch_dims.size(), precision_config, &rewriter);
6584 q_panel = rewriter.create<AddOp>(op.getLoc(), q_panel, q_update);
6585 q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter);
6586 }
    // full_matrices is false when only a partial result is needed. Slice to
    // the needed dimensions here.
6589 if (!op.full_matrices()) {
6590 q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter);
6591 a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter);
6592 }
6593 rewriter.replaceOp(op, {q, a});
6594 return success();
6595 }
6596
6597 private:
6598 // Computes a Householder reflection of the form:
6599 // H = I - tau v v.T.
6600 // such that
6601 // H . ( x1 ) = ( x1 )
6602 // ( x2 ) = ( x2 )
6603 // ( ... ) = ( ... )
6604 // ( xk ) = ( beta )
6605 // ( ... ) ( 0 )
6606 // ( ... ) ( 0 )
6607 // Unlike the usual formulation, we allow the caller to supply 'k' rather than
6608 // only providing the relevant part of 'x' to maintain XLA's static shape
6609 // invariant. In addition, the implementation supports batching.
6610 // Pseudo-code, without batching:
6611 // alpha = x[k]
6612 // x_copy = np.copy(x)
6613 // x_copy[:k+1] = 0
6614 // xnorm = norm2(x_copy)
6615 // if xnorm == 0:
6616 // beta = alpha
6617 // tau = 0
6618 // v = np.zeros_like(x)
6619 // else:
6620 // beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
6621 // tau = (beta - alpha) / beta
6622 // v = x / (alpha - beta)
6623 // v[k] = 1
6624 // return (v, tau, beta)
  void House(Location loc, Value x, Value k, ArrayRef<int64_t> batch_dims,
             const int64_t m, OpBuilder *builder, Value *v, Value *tau,
             Value *beta) const {
6628 auto x_type = x.getType().cast<RankedTensorType>();
6629
6630 llvm::SmallVector<int64_t, 4> batch_dim_ids(batch_dims.size());
6631 std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
6632 const int64_t minor_dim = batch_dims.size();
6633
6634 Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder);
6635 Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder);
6636
6637 // alpha = x[k]
6638 Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder);
6639 alpha = builder->create<ReshapeOp>(
6640 loc, RankedTensorType::get(batch_dims, x_type.getElementType()), alpha);
6641
6642 // Compute x[k+1:] (padded with zeros in elements 0..k)
6643 Value iota = builder->create<IotaOp>(
6644 loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
6645 builder->getI64IntegerAttr(0));
6646 Value gtk = builder->create<chlo::BroadcastCompareOp>(
6647 loc, iota, k, GetI64ElementsAttr({}, builder), ComparisonDirection::GT);
6648 gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
6649 Value x_after_k = builder->create<chlo::BroadcastMulOp>(
6650 loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
6651 Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
6652 // sigma = np.dot(x[k+1:], x[k+1:])
6653 auto sigma = builder->create<ReduceOp>(
6654 loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder));
6655 BuildReduceBody<AddOp>(x_type.getElementType(), &sigma.body(), builder);
6656 // mu = np.sqrt(x[k]*x[k] + sigma)
6657 Value alpha_sq = builder->create<MulOp>(loc, alpha, alpha);
6658 Value mu = builder->create<SqrtOp>(
6659 loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
6660
6661 Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
6662 loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
6663 ComparisonDirection::EQ);
6664 Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
6665 loc, alpha, zero, GetI64ElementsAttr({}, builder),
6666 ComparisonDirection::LT);
6667 auto batch_size_one = builder->create<BroadcastOp>(
6668 loc, one, GetI64ElementsAttr(batch_dims, builder));
6669 Value signed_mu = builder->create<chlo::BroadcastMulOp>(
6670 loc,
6671 builder->create<SelectOp>(loc, alpha_is_negative, batch_size_one,
6672 builder->create<NegOp>(loc, batch_size_one)),
6673 mu, GetI64ElementsAttr({}, builder));
6674 *beta = builder->create<SelectOp>(loc, sigma_is_zero, alpha, signed_mu);
6675 *tau = builder->create<DivOp>(
6676 loc, builder->create<SubtractOp>(loc, *beta, alpha), *beta);
6677 Value zero_tau = builder->create<BroadcastOp>(
6678 loc, zero, GetI64ElementsAttr(batch_dims, builder));
6679 *tau = builder->create<SelectOp>(loc, sigma_is_zero, zero_tau, *tau);
6680 Value divisor = builder->create<SubtractOp>(loc, alpha, *beta);
6681 divisor =
6682 builder->create<SelectOp>(loc, sigma_is_zero, batch_size_one, divisor);
6683
6684 Value eqk = builder->create<chlo::BroadcastCompareOp>(
6685 loc, iota, k, GetI64ElementsAttr({}, builder), ComparisonDirection::EQ);
6686 eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
6687 llvm::SmallVector<int64_t, 4> e_k_shape(batch_dims.size(), 1);
6688 e_k_shape.push_back(m);
6689 auto e_k = builder->create<BroadcastOp>(
6690 loc, eqk,
6691 GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(batch_dims.size(), 1),
6692 builder));
6693
6694 // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
6695 // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
6696 // Note that the add performs a degenerate broadcast.
6697 *v = builder->create<chlo::BroadcastAddOp>(
6698 loc, e_k,
6699 StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
6700 GetI64ElementsAttr(batch_dim_ids, builder),
6701 *builder),
6702 /*broadcast_dimensions=*/nullptr);
6703 }
6704
6705 // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
6706 // Loan "Matrix Computations", 4th Edition. This is an unblocked
6707 // implementation used as an inner routine of the blocked implementation.
6708 // Algorithm is adapted slightly so the shapes inside the loop are static, at
6709 // the cost of some redundant computation. Since this is used as an inner
6710 // block kernel, accumulates the Householder transformations (vs, taus) rather
  // than the matrix q.
  //
  // Equivalent Python code, without batching:
  // def qr(a):
6712 // m = a.shape[0]
6713 // n = a.shape[1]
6714 // vs = np.zeros([m, n])
6715 // taus = np.zeros([n])
6716 // for j in xrange(min(m, n)):
6717 // v, tau, beta = house(a[:, j], j)
6718 // # Unusually, we apply the Householder transformation to the entirety of
6719 // # a, wasting FLOPs to maintain the static shape invariant that XLA
6720 // # requires. For columns that precede j this has no effect.
6721 // a[:, :] -= tau * np.dot(v[:, np.newaxis],
6722 // np.dot(v[np.newaxis, :], a[:, :]))
6723 // # Form column j explicitly rather than relying on the precision of the
6724 // # Householder update.
6725 // a[j, j] = beta
6726 // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
6727 // vs[:, j] = v
6728 // taus[j] = tau
  //   return (a, vs, taus)
  void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs,
               PatternRewriter *rewriter) const {
6732 auto a_type = a.getType().cast<RankedTensorType>();
6733 const int num_dims = a_type.getRank();
6734 assert(num_dims >= 2 && "Argument to QR must have rank >= 2");
6735
6736 const int64_t m = a_type.getDimSize(a_type.getRank() - 2);
6737 const int64_t n = a_type.getDimSize(a_type.getRank() - 1);
6738
6739 const int64_t num_batch_dims = num_dims - 2;
6740 auto batch_dims = a_type.getShape().take_front(num_batch_dims);
6741 llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
6742 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
6743
6744 auto qr_body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6745 SmallVectorImpl<Value> *new_values,
6746 OpBuilder *builder) {
6747 auto a = old_values[0];
6748 auto vs = old_values[1];
6749 auto taus = old_values[2];
6750
6751 // v, beta = house(a[:, j], j)
6752 auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder);
6753 auto x_collapsed_shape = llvm::to_vector<4>(batch_dims);
6754 x_collapsed_shape.push_back(m);
6755 auto x_collapsed = builder->create<ReshapeOp>(
6756 loc,
6757 RankedTensorType::get(x_collapsed_shape,
6758 getElementTypeOrSelf(x.getType())),
6759 x);
6760 Value v, tau, beta;
6761 House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta);
6762
6763 auto shape = llvm::to_vector<4>(batch_dims);
6764 shape.append({1, m});
6765 auto v_broadcast = builder->create<ReshapeOp>(
6766 loc, RankedTensorType::get(shape, getElementTypeOrSelf(v.getType())),
6767 v);
6768 // a[:, :] -= tau * np.dot(v[:, np.newaxis],
6769 // np.dot(v[np.newaxis, :], a[:, :]))
6770 auto precision = builder->getArrayAttr(
6771 {PrecisionAttr::get(builder->getContext(), Precision::HIGHEST),
6772 PrecisionAttr::get(builder->getContext(), Precision::HIGHEST)});
6773 auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims,
6774 precision, builder);
6775 vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
6776 precision, builder);
6777 auto tau_x_vva = StaticBinaryBroadcast<mhlo::MulOp>(
6778 loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
6779 *builder);
6780 a = builder->create<SubtractOp>(loc, a, tau_x_vva);
6781
6782 // It is more precise to populate column 'k' explicitly, rather than
6783 // computing it implicitly by applying the Householder transformation.
6784 // a[k,k] = beta
6785 // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
6786 auto iota = builder->create<IotaOp>(
6787 loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
6788 builder->getI64IntegerAttr(0));
6789 Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
6790 loc, iota, j, GetI64ElementsAttr({}, builder),
6791 ComparisonDirection::LT);
6792 predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
6793 a_type.getElementType());
6794 Value mask = builder->create<chlo::BroadcastCompareOp>(
6795 loc, iota, j, GetI64ElementsAttr({}, builder),
6796 ComparisonDirection::EQ);
6797 mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
6798 mask = builder->create<BroadcastOp>(
6799 loc,
6800 mask,
6801 GetI64ElementsAttr(llvm::SmallVector<int64_t, 4>(num_batch_dims, 1),
6802 builder));
6803 Value predecessor_masked_x = StaticBinaryBroadcast<MulOp>(
6804 loc, x, predecessor_mask,
6805 GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder);
6806 Value masked_beta = StaticBinaryBroadcast<MulOp>(
6807 loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder),
6808 *builder);
6809 Value new_x =
6810 builder->create<AddOp>(loc, predecessor_masked_x, masked_beta);
6811 // Update a[:,j]
6812 llvm::SmallVector<int64_t, 4> dim_ids(num_dims);
6813 std::iota(dim_ids.begin(), dim_ids.end(), 0);
6814 new_x = builder->create<BroadcastInDimOp>(
6815 loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder));
6816 const int64_t minor_dim = num_batch_dims;
6817 auto iota_mn = builder->create<IotaOp>(
6818 loc,
6819 RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
6820 builder->getI64IntegerAttr(minor_dim + 1));
6821 Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
6822 loc, iota_mn, j, GetI64ElementsAttr({}, builder),
6823 ComparisonDirection::EQ);
6824 a = builder->create<SelectOp>(loc, xa_mask, new_x, a);
6825
6826 // vs[:, j] = v
6827 llvm::SmallVector<int64_t, 4> vs_broadcast_dims(num_batch_dims + 1);
6828 std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0);
6829 Value vs_zeros =
6830 GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
6831 vs_zeros = builder->create<BroadcastOp>(
6832 loc, vs_zeros,
6833 GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
6834 builder));
6835 auto vs_update = builder->create<SelectOp>(
6836 loc, xa_mask,
6837 StaticBinaryBroadcast<AddOp>(
6838 loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder),
6839 *builder),
6840 vs_zeros);
6841 vs = builder->create<AddOp>(loc, vs, vs_update);
6842
6843 // taus[j] = tau
6844 llvm::SmallVector<int64_t, 4> tau_broadcast_dims(batch_dims.size());
6845 std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
6846
6847 auto iota_shape = llvm::to_vector<4>(batch_dims);
6848 iota_shape.push_back(n);
6849 auto iota_n = builder->create<IotaOp>(
6850 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
6851 builder->getI64IntegerAttr(minor_dim));
6852 Value taus_zeros =
6853 GetScalarConstOfType(a_type.getElementType(), loc, 0, builder);
6854 taus_zeros = builder->create<BroadcastOp>(
6855 loc, taus_zeros,
6856 GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
6857 builder));
6858 Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
6859 loc, iota_n, j, GetI64ElementsAttr({}, builder),
6860 ComparisonDirection::EQ);
6861 auto taus_update = builder->create<SelectOp>(
6862 loc, taus_mask,
6863 StaticBinaryBroadcast<AddOp>(
6864 loc, taus_zeros, tau,
6865 GetI64ElementsAttr(tau_broadcast_dims, builder), *builder),
6866 taus_zeros);
6867 taus = builder->create<AddOp>(loc, taus, taus_update);
6868 new_values->assign({a, vs, taus});
6869 };
6870
6871 Value zero =
6872 GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter);
6873 *vs = rewriter->create<BroadcastOp>(
6874 loc, zero, GetI64ElementsAttr(a_type.getShape(), rewriter));
6875 auto taus_shape = llvm::to_vector<4>(batch_dims);
6876 taus_shape.push_back(n);
6877 *taus = rewriter->create<BroadcastOp>(
6878 loc, zero, GetI64ElementsAttr(taus_shape, rewriter));
6879
6880 SmallVector<Value, 4> while_output;
6881 CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus},
6882 &while_output, rewriter);
6883 *r = while_output[0];
6884 *vs = while_output[1];
6885 *taus = while_output[2];
6886 }
6887
  // Computes W and Y such that I-WY is equivalent to the sequence of
  // Householder transformations given by vs and taus.
6891 // Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
6892 // Y = np.zeros([m, n])
6893 // W = np.zeros([m, n])
6894 // Y[:, 0] = vs[:, 0]
6895 // W[:, 0] = -taus[0] * vs[:, 0]
6896 // for j in xrange(1, n):
6897 // v = vs[:, j]
6898 // z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
6899 // W[:, j] = z
6900 // Y[:, j] = v
6901 // return W
6902 // There is no need to return Y since at termination of the loop it is equal
6903 // to vs.
  Value ComputeWYRepresentation(Location loc, Type data_type,
                                ArrayRef<int64_t> batch_dims, Value vs,
                                Value taus, int64_t m, int64_t n,
                                PatternRewriter *rewriter) const {
6908 int64_t n_index = batch_dims.size() + 1;
6909 llvm::SmallVector<int64_t, 4> batch_dim_indices(batch_dims.size());
6910 std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
6911
6912 auto body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values,
6913 SmallVectorImpl<Value> *new_values, OpBuilder *builder) {
6914 // w has shape [..., m, n]
6915 auto w = old_values[0];
6916 const auto vs = old_values[1];
6917 const auto taus = old_values[2];
6918
6919 // Want j values in range [1, ... n).
6920 j = builder->create<AddOp>(
6921 loc, j,
6922 GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1,
6923 builder));
6924 // vs has shape [..., m, 1]
6925 auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder);
6926 // beta has shape [..., 1]
6927 auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder);
6928
6929 auto iota_shape = llvm::to_vector<4>(batch_dims);
6930 iota_shape.append({m, n});
6931 auto iota_mn = builder->create<IotaOp>(
6932 loc, RankedTensorType::get(iota_shape, builder->getIntegerType(32)),
6933 builder->getI64IntegerAttr(n_index));
6934
6935 // y has shape [..., m, n]
6936 Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc,
6937 0, builder);
6938 zero = builder->create<BroadcastOp>(
6939 loc, zero,
6940 GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
6941 builder));
6942 auto compare = builder->create<chlo::BroadcastCompareOp>(
6943 loc, iota_mn, j, GetI64ElementsAttr({}, builder),
6944 ComparisonDirection::GE);
6945 auto y = builder->create<SelectOp>(loc, compare, zero, vs);
6946
6947 // yv has shape [..., n, 1]
6948 auto precision = builder->getArrayAttr(
6949 {PrecisionAttr::get(builder->getContext(), Precision::HIGHEST),
6950 PrecisionAttr::get(builder->getContext(), Precision::HIGHEST)});
6951 auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision,
6952 builder);
6953 // wyv has shape [..., m, 1]
6954 auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(),
6955 precision, builder);
6956
6957 // z = -beta * (v + wyv)
6958 auto neg_beta = builder->create<NegOp>(loc, beta);
6959 auto v_wyv = builder->create<AddOp>(loc, v, wyv);
6960 auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6961 beta_broadcast_dims.push_back(n_index);
6962 auto z = StaticBinaryBroadcast<MulOp>(
6963 loc, neg_beta, v_wyv,
6964 GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter);
6965
6966 w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder);
6967 new_values->assign({w, vs, taus});
6968 };
6969
6970 Value w =
6971 GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter);
6972 auto w_shape = llvm::to_vector<4>(batch_dims);
6973 w_shape.append({m, n});
6974 w = rewriter->create<BroadcastOp>(loc,
6975 w, GetI64ElementsAttr(w_shape, rewriter));
6976 auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter);
6977 auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter);
6978 auto neg_beta = rewriter->create<NegOp>(loc, beta);
6979 auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices);
6980 beta_broadcast_dims.push_back(n_index);
6981 auto bv = StaticBinaryBroadcast<MulOp>(
6982 loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter),
6983 *rewriter);
6984 w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter);
6985
6986 SmallVector<Value, 4> while_output;
6987 CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter);
6988 return while_output[0];
6989 }
6990 };
6991
// Converts tf.XlaConvV2 to mhlo.Convolution
6993 class ConvertXlaConvV2Op : public OpRewritePattern<TF::XlaConvV2Op> {
6994 public:
6995 using OpRewritePattern::OpRewritePattern;
6996
  LogicalResult matchAndRewrite(TF::XlaConvV2Op op,
                                PatternRewriter &rewriter) const override {
6999 DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr,
7000 rhs_dilation_attr, feature_group_count_attr;
7001 if (!(matchPattern(op.window_strides(), m_Constant(&window_strides_attr)) &&
7002 matchPattern(op.padding(), m_Constant(&padding_attr)) &&
7003 matchPattern(op.lhs_dilation(), m_Constant(&lhs_dilation_attr)) &&
7004 matchPattern(op.rhs_dilation(), m_Constant(&rhs_dilation_attr)) &&
7005 matchPattern(op.feature_group_count(),
7006 m_Constant(&feature_group_count_attr))))
7007 return failure();
7008
7009 auto window_strides_named_attr = rewriter.getNamedAttr(
7010 "window_strides", hlo::convertElementsAttr(window_strides_attr,
7011 rewriter.getIntegerType(64))
7012 .cast<DenseIntElementsAttr>());
7013
7014 auto padding_named_attr = rewriter.getNamedAttr(
7015 "padding",
7016 hlo::convertElementsAttr(padding_attr, rewriter.getIntegerType(64))
7017 .cast<DenseIntElementsAttr>());
7018
7019 auto lhs_dilation_named_attr = rewriter.getNamedAttr(
7020 "lhs_dilation",
7021 hlo::convertElementsAttr(lhs_dilation_attr, rewriter.getIntegerType(64))
7022 .cast<DenseIntElementsAttr>());
7023
7024 auto rhs_dilation_named_attr = rewriter.getNamedAttr(
7025 "rhs_dilation",
7026 hlo::convertElementsAttr(rhs_dilation_attr, rewriter.getIntegerType(64))
7027 .cast<DenseIntElementsAttr>());
7028
7029 int64_t feature_group_count_val =
7030 feature_group_count_attr.getValues<IntegerAttr>()[0].getInt();
7031 auto feature_group_count_named_attr = rewriter.getNamedAttr(
7032 "feature_group_count",
7033 rewriter.getI64IntegerAttr(feature_group_count_val));
7034
7035 auto batch_group_count_named_attr =
7036 rewriter.getNamedAttr("batch_group_count", op.batch_group_countAttr());
7037
7038 xla::ConvolutionDimensionNumbers dnums;
7039 dnums.ParseFromString(op.dimension_numbersAttr().getValue().str());
7040 auto dimension_numbers_named_attr = rewriter.getNamedAttr(
7041 "dimension_numbers",
7042 xla::ConvertConvDimensionNumbers(dnums, &rewriter));
7043
7044 xla::PrecisionConfig precision_config;
7045 precision_config.ParseFromString(
7046 op.precision_configAttr().getValue().str());
7047 auto precision_config_named_attr = rewriter.getNamedAttr(
7048 "precision_config",
7049 xla::ConvertPrecisionConfig(&precision_config, &rewriter));
7050
7051 SmallVector<Value, 2> operands{op.lhs(), op.rhs()};
7052 NamedAttribute attrs[] = {
7053 window_strides_named_attr, padding_named_attr,
7054 lhs_dilation_named_attr, rhs_dilation_named_attr,
7055 feature_group_count_named_attr, batch_group_count_named_attr,
7056 dimension_numbers_named_attr, precision_config_named_attr};
7057 rewriter.replaceOpWithNewOp<mhlo::ConvolutionOp>(op, op.getType(), operands,
7058 llvm::makeArrayRef(attrs));
7059 return success();
7060 }
7061 };
7062
7063 // Converts tf.XlaSelectAndScatter to mhlo.SelectAndScatter
7064 class ConvertXlaSelectAndScatterOp
7065 : public OpRewritePattern<TF::XlaSelectAndScatterOp> {
7066 public:
7067 using OpRewritePattern::OpRewritePattern;
7068
  LogicalResult matchAndRewrite(TF::XlaSelectAndScatterOp op,
                                PatternRewriter &rewriter) const override {
7071 ElementsAttr window_dimensions, window_strides, padding;
7072 if (!(matchPattern(op.window_dimensions(),
7073 m_Constant(&window_dimensions)) &&
7074 matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
7075 matchPattern(op.padding(), m_Constant(&padding))))
7076 return failure();
7077
7078 Location loc = op.getLoc();
7079
7080 SmallVector<Type> result_types{op.getResult().getType()};
7081 // Create the mhlo.SelectAndScatter op.
7082 auto select_and_scatter_op = rewriter.create<mhlo::SelectAndScatterOp>(
7083 loc, result_types, op.operand(), op.source(), op.init_value(),
7084 hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64))
7085 .cast<DenseIntElementsAttr>(),
7086 hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64))
7087 .cast<DenseIntElementsAttr>(),
7088 hlo::convertElementsAttr(padding, rewriter.getIntegerType(64))
7089 .cast<DenseIntElementsAttr>());
7090
7091 auto insert_call_to = [&](const mlir::SymbolRefAttr &func, Region *region) {
7092 auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
7093 op->getParentOfType<mlir::ModuleOp>(), func));
7094 auto func_ty = func_op.getFunctionType();
7095 BuildBodyWithCall(rewriter, loc, func, func_ty, region);
7096 };
7097
7098 // Insert a call to the select function in the select region of the mhlo op.
7099 insert_call_to(op.select(), &select_and_scatter_op.select());
7100 // Insert a call to the scatter function in the scatter region of the mhlo
7101 // op.
7102 insert_call_to(op.scatter(), &select_and_scatter_op.scatter());
7103
7104 rewriter.replaceOp(op, select_and_scatter_op.getResult());
7105
7106 return success();
7107 }
7108 };
7109
7110 // Convert tf.XlaSort to mhlo.Sort
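// The result is sorted along the last dimension in ascending order, without a
// stability guarantee.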
7111 class ConvertXlaSortOp : public OpRewritePattern<TF::XlaSortOp> {
7112 public:
7113 using OpRewritePattern::OpRewritePattern;
7114
  LogicalResult matchAndRewrite(TF::XlaSortOp op,
                                PatternRewriter &rewriter) const override {
7117 // Create the sort op.
7118 Type element_type = getElementTypeOrSelf(op.input().getType());
7119 auto sort_op =
7120 createSortOp(&rewriter, op.getLoc(), {op.input()}, {element_type},
7121 /*dimension=*/-1, /*is_stable=*/false,
7122 /*direction=*/ComparisonDirection::LT);
7123 rewriter.replaceOp(op, sort_op.getResult(0));
7124 return success();
7125 }
7126 };
7127
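// Maps a TensorFlow RNG algorithm to the corresponding XLA random algorithm,
// or returns llvm::None if the algorithm has no XLA equivalent.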
inline llvm::Optional<xla::RandomAlgorithm> TensorFlowRngAlgToXla(
    tensorflow::Algorithm alg) {
7130 if (alg == tensorflow::RNG_ALG_PHILOX) {
7131 return xla::RandomAlgorithm::RNG_PHILOX;
7132 } else if (alg == tensorflow::RNG_ALG_THREEFRY) {
7133 return xla::RandomAlgorithm::RNG_THREE_FRY;
7134 } else if (alg == tensorflow::RNG_ALG_AUTO_SELECT) {
7135 return xla::RandomAlgorithm::RNG_DEFAULT;
7136 }
7137 return llvm::None;
7138 }
7139
7140 // Converts tf.XlaRngBitGenerator op to mhlo.RngBitGenerator op.
7141 class ConvertXlaRngBitGeneratorOp
7142 : public OpRewritePattern<TF::XlaRngBitGeneratorOp> {
7143 public:
7144 using OpRewritePattern::OpRewritePattern;
7145
  LogicalResult matchAndRewrite(TF::XlaRngBitGeneratorOp op,
                                PatternRewriter &rewriter) const override {
7148 Location loc = op.getLoc();
7149 DenseElementsAttr algorithm;
7150 if (!(matchPattern(op.algorithm(), m_Constant(&algorithm))) ||
7151 algorithm.getType().getRank()) {
7152 return op.emitOpError() << "algorithm must be a constant scalar";
7153 }
7154 auto alg = static_cast<tensorflow::Algorithm>(
7155 algorithm.getValues<IntegerAttr>()[0].getInt());
7156 auto xla_alg = TensorFlowRngAlgToXla(alg);
7157 if (!xla_alg) {
7158 return op.emitOpError() << "unknown algorithm";
7159 }
7160
7161 auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get(
7162 rewriter.getContext(),
7163 *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.getValue()));
7164 auto rng_bit_generator_op = rewriter.create<mhlo::RngBitGeneratorOp>(
7165 loc, op.getResultTypes(), algorithm_attr, op.initial_state());
7166
7167 rewriter.replaceOp(op, rng_bit_generator_op.getResults());
7168
7169 return success();
7170 }
7171 };
7172
7173 // Converts tf.XlaVariadicReduceV2 to mhlo.Reduce
7174 class ConvertXlaVariadicReduceV2Op
7175 : public OpRewritePattern<TF::XlaVariadicReduceV2Op> {
7176 public:
7177 using OpRewritePattern::OpRewritePattern;
7178
  LogicalResult matchAndRewrite(TF::XlaVariadicReduceV2Op op,
                                PatternRewriter &rewriter) const override {
7181 Location loc = op.getLoc();
7182
7183 // Create the mhlo.reduce op.
7184 auto reduce_op = rewriter.create<mhlo::ReduceOp>(
7185 loc, op.inputs(), op.init_values(),
7186 GetI64ElementsAttr(op.dimensions_to_reduce()));
7187 mlir::SymbolRefAttr func = op.reducer();
7188 auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
7189 op->getParentOfType<mlir::ModuleOp>(), func));
7190 auto func_ty = func_op.getFunctionType();
7191 // Insert a call to the reducer in the region of the mhlo op.
7192 BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_op.body());
7193
7194 rewriter.replaceOp(op, reduce_op.getResults());
7195
7196 return success();
7197 }
7198 };
7199
7200 // Convert tf.XlaVariadicSort to mhlo.Sort
7201 class ConvertXlaVariadicSortOp
7202 : public OpRewritePattern<TF::XlaVariadicSortOp> {
7203 public:
7204 using OpRewritePattern::OpRewritePattern;
7205
  LogicalResult matchAndRewrite(TF::XlaVariadicSortOp op,
                                PatternRewriter &rewriter) const override {
7208 Location loc = op.getLoc();
7209 ElementsAttr dimension;
7210 matchPattern(op.dimension(), m_Constant(&dimension));
7211 // Create the mhlo.sort op.
7212 auto sort_op = rewriter.create<mhlo::SortOp>(
7213 loc, op.inputs(), dimension.getValues<IntegerAttr>()[0].getInt(),
7214 op.is_stable());
7215 mlir::SymbolRefAttr func = op.comparator();
7216 auto func_op = cast<mlir::func::FuncOp>(SymbolTable::lookupSymbolIn(
7217 op->getParentOfType<mlir::ModuleOp>(), func));
7218 auto func_ty = func_op.getFunctionType();
    // Insert a call to the comparator in the mhlo op's comparator region.
7220 BuildBodyWithCall(rewriter, loc, func, func_ty, &sort_op.comparator());
7221
7222 rewriter.replaceOp(op, sort_op.getResults());
7223 return success();
7224 }
7225 };
7226
7227 // Convert tf.XlaReducePrecision to mhlo.ReducePrecision
7228 class ConvertXlaReducePrecisionOp
7229 : public OpRewritePattern<TF::XlaReducePrecisionOp> {
7230 public:
7231 using OpRewritePattern::OpRewritePattern;
7232
  LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op,
                                PatternRewriter &rewriter) const override {
7235 IntegerType int32_type = rewriter.getIntegerType(32);
7236 APInt exponent_bits = op.exponent_bitsAttr().getValue();
    // Truncating to 32-bits is safe, since passing any number above the dtype
7238 // size (which is at most 64, for float64) is equivalent to passing the
7239 // dtype size.
7240 IntegerAttr new_exponent_attr =
7241 IntegerAttr::get(int32_type, exponent_bits.truncSSat(32));
7242 APInt mantissa_bits = op.mantissa_bitsAttr().getValue();
7243 IntegerAttr new_mantissa_attr =
7244 IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32));
7245 rewriter.replaceOpWithNewOp<mhlo::ReducePrecisionOp>(
7246 op, op.getType(), op.operand(), new_exponent_attr, new_mantissa_attr);
7247 return success();
7248 }
7249 };
7250
7251 } // end namespace
7252
7253 #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
7254
void PopulateLegalizeTfPatterns(MLIRContext *context,
                                RewritePatternSet *patterns) {
7257 populateWithGenerated(*patterns);
7258 // clang-format off
7259 patterns->add<
7260 ConvertAllOp,
7261 ConvertAnyOp,
7262 ConvertArgMaxOp,
7263 ConvertArgMinOp,
7264 ConvertBatchMatMulV2Op,
7265 ConvertBiasAddOp,
7266 ConvertBroadcastToOp,
7267 ConvertBF16FloorDivOp,
7268 ConvertClipByValueOp,
7269 ConvertConstOp,
7270 ConvertConv2DOp,
7271 ConvertConv3DOp,
7272 ConvertDepthConv2DOp,
7273 ConvertConv2DBackpropFilterOp,
7274 ConvertConv3DBackpropFilterOp,
7275 ConvertConv2DBackpropInputOp,
7276 ConvertConv3DBackpropInputOp,
7277 ConvertCumprodOp,
7278 ConvertCumsumOp,
7279 ConvertDiagPartOp,
7280 ConvertDynamicExpandDimsOp,
7281 ConvertDynamicSqueezeOp,
7282 ConvertEinsumOp,
7283 ConvertRFFTOp,
7284 ConvertIRFFTOp,
7285 ConvertFusedBatchNormGradOp,
7286 ConvertFusedBatchNormGradV2Op,
7287 ConvertFusedBatchNormGradV3Op,
7288 ConvertFusedBatchNormV2Op,
7289 ConvertFusedBatchNormV3Op,
7290 ConvertInfeedDequeueTupleOp,
7291 ConvertIdentityNOp,
7292 ConvertInplaceUpdateOp,
7293 ConvertLinSpaceOp,
7294 ConvertMaxOp,
7295 ConvertMinOp,
7296 ConvertAvgPool2DOp,
7297 ConvertAvgPool3DOp,
7298 ConvertAvgPool2DGradOp,
7299 ConvertAvgPool3DGradOp,
7300 ConvertMaxPool2DOp,
7301 ConvertMaxPool3DOp,
7302 ConvertMaxPool2DGradOp,
7303 ConvertMaxPool3DGradOp,
7304 ConvertMeanOp,
7305 ConvertOneHotOp,
7306 ConvertOutfeedEnqueueTupleOp,
7307 ConvertProdOp,
7308 ConvertQrOp,
7309 ConvertDynamicRangeOp,
7310 ConvertMatrixDiagPartV3Op,
7311 ConvertRangeOp,
7312 ConvertSelectOp,
7313 ConvertSigmoidOp,
7314 ConvertShapeOp,
7315 ConvertSplitOp,
7316 ConvertSplitVOp,
7317 ConvertStridedSliceOp,
7318 ConvertStridedSliceGradOp,
7319 ConvertSumOp,
7320 ConvertTensorScatterAddOp,
7321 ConvertTensorScatterSubOp,
7322 ConvertTensorScatterMinOp,
7323 ConvertTensorScatterMaxOp,
7324 ConvertTensorScatterUpdateOp,
7325 ConvertTileOp,
7326 ConvertTopKV2Op,
7327 ConvertUnpackOp,
7328 ConvertUnsortedSegmentMaxOp,
7329 ConvertUnsortedSegmentMinOp,
7330 ConvertUnsortedSegmentProdOp,
7331 ConvertUnsortedSegmentSumOp,
7332 ConvertRandomShuffleOp,
7333 ConvertXlaShardingOp,
7334 ConvertXlaDynamicUpdateSliceOp,
7335 ConvertXlaConvV2Op,
7336 ConvertXlaReducePrecisionOp,
7337 ConvertXlaReduceScatterOp,
7338 ConvertXlaReduceWindowOp,
7339 ConvertXlaRngBitGeneratorOp,
7340 ConvertXlaSelectAndScatterOp,
7341 ConvertXlaSortOp,
7342 ConvertXlaVariadicReduceV2Op,
7343 ConvertXlaVariadicSortOp,
7344 ConvertRollOp,
7345 ConvertLeakyReluOp,
7346 ConvertLeakyReluGradOp,
7347 ConvertSplitOpDynamic,
7348 ConvertSliceOpDynamic,
7349 ConvertTileOpDynamic,
7350 ConvertUnpackOpDynamic,
7351 ConvertSigmoidGradOpDynamic,
7352 ConvertConv2DDynamic,
7353 ConvertPadOpDynamic,
7354 ConvertGatherNdOpDynamic>(context);
7355 // clang-format on
7356 }
7357
7358 } // end namespace mhlo
7359 } // end namespace mlir
7360