1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
17
18 #include <array>
#include <cmath>
#include <limits>
#include <numeric>
#include <string>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
25 #include "mlir/IR/Diagnostics.h" // from @llvm-project
26 #include "mlir/IR/MLIRContext.h" // from @llvm-project
27 #include "mlir/IR/Matchers.h" // from @llvm-project
28 #include "mlir/IR/PatternMatch.h" // from @llvm-project
29 #include "mlir/IR/TypeRange.h" // from @llvm-project
30 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
37 #include "tensorflow/core/util/tensor_format.h"
38
39 namespace mlir {
40 namespace TF {
41 namespace {
42
43 // Returns a 1D 64-bit dense elements attribute with the given values.
44 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
45 Builder *builder) {
46 RankedTensorType ty = RankedTensorType::get(
47 {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
48 return DenseIntElementsAttr::get(ty, values);
49 }
50
51 // Returns a 1-d i64 elements attribute populated with numbers from `start` to
52 // `end` (exclusive).
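// For example (illustrative), GetI64ElementsAttrForSeq(2, 5, &builder) yields a
// tensor<3xi64> attribute holding [2, 3, 4].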
53 static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
54 Builder *builder) {
55 int size = end - start;
56
57 SmallVector<int64_t, 4> vals;
58 vals.resize(size);
59 std::iota(vals.begin(), vals.end(), start);
60
61 TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
62 return DenseIntElementsAttr::get(ty, vals);
63 }
64
65 // Returns a scalar f32 elements attribute holding the given value.
66 static DenseElementsAttr GetF32Scalar(OpBuilder *builder, float value) {
67 return DenseElementsAttr::get(
68 RankedTensorType::get({}, builder->getF32Type()),
69 FloatAttr::get(builder->getF32Type(), value));
70 }
71
72 // Returns a TF_CastOp to F32. This function is used for CastOps that are
73 // intermediate nodes in a TableGen pattern result. In such a case, the
74 // destination type is not inferred and must be given explicitly.
75 //
76 // Preconditions: The given value must have a ShapedType.
77 static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x,
78 BoolAttr truncate) {
79 auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
80 if (!x_type) llvm_unreachable("unsupported type");
81 Type type = x_type.clone(builder->getF32Type());
82 return builder->create<CastOp>(loc, type, x, truncate);
83 }
84
85 // Returns a TF_CastOp to I32. This function is used for CastOps that are
86 // intermediate nodes in a TableGen pattern result. In such a case, the
87 // destination type is not inferred and must be given explicitly.
88 //
89 // Preconditions: The given value must have a ShapedType.
90 static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x,
91 BoolAttr truncate) {
92 auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
93 if (!x_type) llvm_unreachable("unsupported type");
94 Type type = x_type.clone(builder->getI32Type());
95 return builder->create<CastOp>(loc, type, x, truncate);
96 }
97
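// Converts `val` to an APFloat whose precision matches the bit width of `type`:
// single precision for 32-bit types, double precision otherwise.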
98 static APFloat ConvertToAPFloat(double val, Type type) {
99 if (type.getIntOrFloatBitWidth() == 32) {
100 return APFloat(static_cast<float>(val));
101 }
102
103 return APFloat(val);
104 }
105
106 // Return true if the passed quantized type is unsigned.
107 bool QuantizedTypeIsUnsigned(Type type) {
108 return TypeSwitch<Type, bool>(type)
109 .Case<mlir::TF::Qint8Type>([](Type) { return false; })
110 .Case<mlir::TF::Qint16Type>([](Type) { return false; })
111 .Case<mlir::TF::Qint32Type>([](Type) { return false; })
112 .Case<mlir::TF::Quint8Type>([](Type) { return true; })
113 .Case<mlir::TF::Quint16Type>([](Type) { return true; })
114 .Default([](Type) {
115 llvm_unreachable("QuantizedTypeIsUnsigned: not a quantized type");
116 return false;
117 });
118 }
119
120 // Returns the half_range value that is used by DequantizeOp. half_range is
121 // used to offset the quantized representation before it gets scaled. For
122 // signed quantized types, this offset is half the type's range.
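// For example (illustrative), a tf.qint8 input yields a half_range of 128.0,
// while a tf.quint8 input yields 0.0.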
123 static DenseElementsAttr DequantizeHalfRange(OpBuilder *builder, Value input) {
124 auto input_type = input.getType().dyn_cast_or_null<ShapedType>();
125 if (!input_type) llvm_unreachable("DequantizeHalfRange: not a ShapedType");
126 bool is_unsigned = QuantizedTypeIsUnsigned(input_type.getElementType());
127 float half_range = is_unsigned ? 0 : 128;
128 return GetScalarOfType(builder->getF32Type(), half_range);
129 }
130
131 // Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
132 // op.
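// For example (illustrative), with rank = 4 and data_format = "NHWC" the
// feature dimension index is 3, so the returned reduction indices are
// [0, 1, 2].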
133 DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
134 StringAttr data_format,
135 Builder *builder) {
136 tensorflow::TensorFormat format;
137 if (!FormatFromString(data_format.getValue().str(), &format)) return {};
138
139 // Reduce along all dimensions except the feature dimension.
140 int64_t feature_dim = GetTensorFeatureDimIndex(rank, format);
141 llvm::SmallVector<int64_t, 4> dims_to_reduce(rank - 1);
142 std::iota(dims_to_reduce.begin(), dims_to_reduce.begin() + feature_dim, 0);
143 std::iota(dims_to_reduce.begin() + feature_dim, dims_to_reduce.end(),
144 feature_dim + 1);
145 return GetI64ElementsAttr(dims_to_reduce, builder);
146 }
147
148 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_lower_tf.inc"
149
150 // Infers ExpandDims op output type for the given input type `ty` and dimension
151 // to expand at the given `axis`.
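// For example (illustrative), expanding tensor<2x3xf32> at axis = 1 yields
// tensor<2x1x3xf32>, and axis = -1 yields tensor<2x3x1xf32>.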
152 Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
153 auto ranked_ty = ty.dyn_cast<RankedTensorType>();
154
155 // Unranked type.
156 if (!ranked_ty) return ty;
157
158 auto shape = llvm::to_vector<4>(ranked_ty.getShape());
159 if (axis < 0) axis += ranked_ty.getRank() + 1;
160
161 shape.insert(shape.begin() + axis, 1);
162 return RankedTensorType::get(shape, ranked_ty.getElementType());
163 }
164
165 // Concatenates individual Values into a tensor of rank 1. Each input Value
166 // must have rank 1 and size 1.
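// For example (illustrative), three tensor<1xi64> values are concatenated along
// axis 0 into a single tensor<3xi64>.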
167 Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype,
168 ArrayRef<Value> vals) {
169 int64_t length = vals.size();
170 auto type = RankedTensorType::get({length}, dtype);
171 auto axis = rewriter.create<ConstOp>(
172 loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
173 return rewriter.create<ConcatV2Op>(loc, type, ValueRange(vals), axis);
174 }
175
176 // Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
177 //
178 // Note that to improve parallelism, this lowering uses tree-based reduction.
179 // For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
180 //
181 // 0 1 2 3 4
182 // | | | | |
183 // ------- ------- |
184 // | | |
185 // 5 6 |
186 // | | |
187 // ------------- |
188 // | |
189 // 7 |
190 // | |
191 // ----------------
192 // |
193 // 8
194 //
195 // Example:
196 //
197 // %result = "tf.AddN"(%0, %1, %2)
198 //
199 // is lowered to:
200 //
201 // %sum0 = "tf.AddV2"(%0, %1)
202 // %result = "tf.AddV2"(%sum0, %2)
203 //
204 // While
205 //
206 // %result = "tf.AddN"(%0, %1, %2, %3, %4)
207 //
208 // is lowered to:
209 //
210 // %sum0 = "tf.AddV2"(%0, %1)
211 // %sum1 = "tf.AddV2"(%2, %3)
212 // %sum2 = "tf.AddV2"(%sum0, %sum1)
213 // %result = "tf.AddV2"(%sum2, %4)
214 //
215 class LowerAddNOp : public RewritePattern {
216 public:
217 explicit LowerAddNOp(MLIRContext *context)
218 : RewritePattern(AddNOp::getOperationName(), 1, context,
219 {AddV2Op::getOperationName()}) {}
220
221 LogicalResult matchAndRewrite(Operation *op,
222 PatternRewriter &rewriter) const override {
223 auto addn_op = cast<AddNOp>(op);
224
225 // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
226 // support variant type so variant types require special handling.
227 if (getElementTypeOrSelf(addn_op.getType()).isa<VariantType>())
228 return failure();
229 llvm::SmallVector<Value, 4> operands(addn_op.inputs().begin(),
230 addn_op.inputs().end());
231
232 int64_t n = operands.size();
233 // Keep doing tree-based reduction while there is more than one operand.
234 while (n > 1) {
235 for (int64_t i = 0; i < n; i += 2) {
236 // Add two adjacent operands if applicable.
237 operands[i / 2] =
238 (i + 1 < n) ? rewriter.create<AddV2Op>(addn_op.getLoc(),
239 operands[i], operands[i + 1])
240 : operands[i];
241 }
242 n = (n + 1) / 2;
243 }
244
245 rewriter.replaceOp(addn_op, operands[0]);
246 return success();
247 }
248 };
249
250 // Lowers DynamicStitch op with constant indices and with static input and
251 // output shapes using Reshape, Unpack and Pack ops.
252 //
253 // %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
254 // %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> :
255 //   tensor<2x2xi32>}
256 // %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1)
257 //   : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>)
258 //   -> tensor<5x2xf32>
259 //
260 // is lowered to
261 //
262 // %shape = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>}
263 // %inp0 = "tf.Reshape"(%arg0, %shape)
264 // : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
265 // %inp1 = "tf.Reshape"(%arg1, %shape)
266 // : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
267 // %items0 = "tf.Unpack"(%inp0) {axis = 0 : i64}
268 //   : (tensor<1x2xf32>) -> tensor<2xf32>
269 // %items1:4 = "tf.Unpack"(%inp1) {axis = 0 : i64}
270 // : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
271 // tensor<2xf32>)
272 // %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
273 // %0 = "tf.Pack"(%items1#3, %items1#2, %items1#1, %items1#0, %items0, %axis)
274 // : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
275 // tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
276 //
277 template <typename OpT>
278 class LowerDynamicStitchOp : public RewritePattern {
279 public:
280 explicit LowerDynamicStitchOp(MLIRContext *context)
281 : RewritePattern(
282 OpT::getOperationName(), 1, context,
283 {ConstOp::getOperationName(), ReshapeOp::getOperationName(),
284 UnpackOp::getOperationName(), PackOp::getOperationName()}) {}
285
286 LogicalResult matchAndRewrite(Operation *src_op,
287 PatternRewriter &rewriter) const override {
288 auto op = cast<OpT>(src_op);
289
290 // Static output type is used to compute intermediate values. Note that the
291 // output type doesn't have to be static but if input types and indices are
292 // constant, then the output type can be statically determined.
293 RankedTensorType out_ty =
294 op.getType().template dyn_cast<RankedTensorType>();
295 if (!out_ty || !out_ty.hasStaticShape()) return failure();
296
297 // Extract all the constant indices' attributes and verify that the data
298 // operands have static shapes.
299 SmallVector<DenseIntElementsAttr, 4> indices;
300 indices.reserve(op.N());
301 for (auto it : llvm::zip(op.indices(), op.data())) {
302 Value index = std::get<0>(it);
303 Value data = std::get<1>(it);
304
305 DenseIntElementsAttr index_attr;
306 if (!matchPattern(index, m_Constant(&index_attr))) return failure();
307 indices.push_back(index_attr);
308
309 RankedTensorType data_ty =
310 data.getType().template dyn_cast<RankedTensorType>();
311 if (!data_ty || !data_ty.hasStaticShape()) return failure();
312 }
313
314 // Compute the type of each item and the shape to use while reshaping inputs
315 // so that they can be unpacked into individual items.
316 ArrayRef<int64_t> item_shape = out_ty.getShape().drop_front(1);
317 auto item_ty = RankedTensorType::get(item_shape, out_ty.getElementType());
318
319 SmallVector<int64_t, 4> packed_shape;
320 packed_shape.push_back(-1);
321 packed_shape.append(item_shape.begin(), item_shape.end());
322 Location loc = op.getLoc();
323 auto packed_shape_val = rewriter.create<ConstOp>(
324 loc, GetI64ElementsAttr(packed_shape, &rewriter));
325
326 // Prepare each output item by unpacking the data and then placing it at the
327 // specified index.
328 SmallVector<Value, 8> values(out_ty.getDimSize(0));
329 for (auto it : llvm::zip(indices, op.data())) {
330 DenseIntElementsAttr index_attr = std::get<0>(it);
331 Value data = std::get<1>(it);
332
333 auto reshaped_data =
334 rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
335 auto num_items = reshaped_data.getType()
336 .template cast<RankedTensorType>()
337 .getShape()[0];
338 auto items = rewriter.create<UnpackOp>(
339 loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
340 /*axis=*/0);
341 for (auto index_item : llvm::zip(index_attr, items.getResults())) {
342 int64_t output_index = std::get<0>(index_item).getSExtValue();
343 Value item = std::get<1>(index_item);
344 values[output_index] = item;
345 }
346 }
347
348 rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
349 return success();
350 }
351 };
352
353 // This pattern performs a manual lowering of FakeQuant, converting between
354 // floating point and quantized space. It is designed to reproduce TF's
355 // implementation, mirroring the previous XLA implementation. The steps are:
356 //
357 // 1. Compute proper quantized bounds. This involves nudging the input bounds.
358 // 2. Convert the input values to quantized space, rounding them.
359 // 3. Convert back into floating point space.
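//
// For example (illustrative), with num_bits = 8 the quantized bounds are
// [0, 255] when narrow_range is false and [1, 255] when narrow_range is true.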
360 class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
361 public:
362 explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
363 : RewritePattern(
364 FakeQuantWithMinMaxVarsOp::getOperationName(), 1, context,
365 {AddV2Op::getOperationName(), SubOp::getOperationName(),
366 ConstOp::getOperationName(), MulOp::getOperationName(),
367 FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
368 DivOp::getOperationName(), RoundOp::getOperationName()}) {}
369
370 LogicalResult matchAndRewrite(Operation *src_op,
371 PatternRewriter &rewriter) const override {
372 auto op = cast<FakeQuantWithMinMaxVarsOp>(src_op);
373
374 auto input = op.inputs();
375 auto input_ty = input.getType().cast<ShapedType>();
376 auto element_ty = input_ty.getElementType();
377 auto scalar_ty = RankedTensorType::get({}, element_ty);
378
379 auto num_bits = op.num_bits();
380 auto narrow_range = op.narrow_range();
381 const double bits_min = narrow_range ? 1 : 0;
382 const double bits_max = (1 << num_bits) - 1;
383
384 auto float_min = op.min();
385 auto float_max = op.max();
386
387 auto float_diff = rewriter.create<SubOp>(op.getLoc(), float_max, float_min);
388
389 // Compute the range when quantized.
390 auto quant_min = rewriter.create<ConstOp>(
391 op.getLoc(), DenseElementsAttr::get(
392 scalar_ty, ConvertToAPFloat(bits_min, element_ty)));
393
394 auto quant_max = rewriter.create<ConstOp>(
395 op.getLoc(), DenseElementsAttr::get(
396 scalar_ty, ConvertToAPFloat(bits_max, element_ty)));
397
398 auto quant_diff = rewriter.create<ConstOp>(
399 op.getLoc(),
400 DenseElementsAttr::get(
401 scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty)));
402
403 auto quant_to_float =
404 rewriter.create<DivOp>(op.getLoc(), float_diff, quant_diff);
405
406 auto float_to_quant =
407 rewriter.create<DivOp>(op.getLoc(), quant_diff, float_diff);
408
409 // During quantization, the quantized min/max values may not line up
410 // perfectly with the specified min/max. Nudge them into the right range.
411 auto min_scaled =
412 rewriter.create<DivOp>(op.getLoc(), float_min, quant_to_float);
413 auto min_scaled_sub =
414 rewriter.create<SubOp>(op.getLoc(), quant_min, min_scaled);
415
416 auto mid_rounded =
417 rewriter.create<RoundOp>(op.getLoc(), scalar_ty, min_scaled_sub);
418
419 auto nudged_zero_point_val = rewriter.create<ClipByValueOp>(
420 op.getLoc(), scalar_ty, mid_rounded, quant_min, quant_max);
421
422 auto quant_min_sub =
423 rewriter.create<SubOp>(op.getLoc(), quant_min, nudged_zero_point_val);
424 auto quant_max_sub =
425 rewriter.create<SubOp>(op.getLoc(), quant_max, nudged_zero_point_val);
426
427 auto nudged_float_min =
428 rewriter.create<MulOp>(op.getLoc(), quant_min_sub, quant_to_float);
429
430 auto nudged_float_max =
431 rewriter.create<MulOp>(op.getLoc(), quant_max_sub, quant_to_float);
432
433 // Now quantize the input value with the approximated min/max values.
434
435 // Move the input value into quantized space
436 Value quantized_input = rewriter.create<ClipByValueOp>(
437 op.getLoc(), input_ty, input, nudged_float_min, nudged_float_max);
438
439 quantized_input = rewriter.create<SubOp>(op.getLoc(), input_ty,
440 quantized_input, nudged_float_min);
441
442 quantized_input = rewriter.create<MulOp>(op.getLoc(), input_ty,
443 quantized_input, float_to_quant);
444
445 // Round the quantized input to the nearest value, rounding halfway cases
446 // toward positive infinity (add 0.5, then floor).
446 auto half_val = rewriter.create<ConstOp>(
447 op.getLoc(),
448 DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
449
450 quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
451 quantized_input, half_val);
452
453 quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
454
455 // Convert back into floating point space.
456 Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
457 quantized_input, quant_to_float);
458
459 output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
460 nudged_float_min);
461
462 rewriter.replaceOp(op, {output});
463 return success();
464 }
465 };
466
467 // Lowers InvertPermutation op to TensorScatterUpdate op.
468 //
469 // Example:
470 //
471 // %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
472 // "tf.InvertPermutation"(%x) : (tensor<5xi32>) -> tensor<5xi32>
473 //
474 // is lowered to
475 //
476 // %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
477 // %start = "tf.Const"() {value = dense<0> : tensor<i32>}
478 // %limit = "tf.Const"() {value = dense<5> : tensor<i32>}
479 // %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
480 // %updates = "tf.Range"(%start, %limit, %delta) :
481 // (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
482 // %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
483 // %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32>, tensor<2xi32>) ->
484 // tensor<5x1xi32>
485 // "tf.TensorScatterUpdate"(%x, %indices, %updates) :
486 // (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
487 //
488 class LowerInvertPermutationOp : public RewritePattern {
489 public:
490 explicit LowerInvertPermutationOp(MLIRContext *context)
491 : RewritePattern(
492 InvertPermutationOp::getOperationName(), 1, context,
493 {ConstOp::getOperationName(), RangeOp::getOperationName(),
494 ReshapeOp::getOperationName(),
495 TensorScatterUpdateOp::getOperationName()}) {}
496
497 LogicalResult matchAndRewrite(Operation *src_op,
498 PatternRewriter &rewriter) const override {
499 auto op = cast<InvertPermutationOp>(src_op);
500
501 Location loc = op.getLoc();
502 auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
503 // x input must have static shape.
504 if (!x_type || !x_type.hasStaticShape()) {
505 return failure();
506 }
507 Type int_type = x_type.getElementType(); // Could be i32 or i64.
508
509 auto result_type = x_type;
510 auto start = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 0));
511 Value limit = rewriter.create<ConstOp>(
512 loc, GetScalarOfType(int_type, x_type.getShape()[0]));
513 auto delta = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 1));
514 // Construct a sequence of numbers [0, 1, ... len(x)-1].
515 auto updates =
516 rewriter.create<RangeOp>(loc, result_type, start, limit, delta);
517
518 auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
519 auto shape = rewriter.create<ConstOp>(
520 loc, DenseElementsAttr::get(
521 shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
522 auto indices = rewriter.create<ReshapeOp>(loc, op.x(), shape);
523
524 rewriter.replaceOpWithNewOp<TensorScatterUpdateOp>(op, result_type, op.x(),
525 indices, updates);
526 return success();
527 }
528 };
529
530 // Approximates lgamma using Lanczos' approximation from
531 // "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
532 // Analysis series B. Vol. 1:
533 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z)
534 //                 + log(A(z))
535 // t(z) = z + kLanczosGamma + 1/2
536 // A(z) = kBaseLanczosCoeff + sigma(k = 1, n, kLanczosCoefficients[k-1] / (z + k))
537 //
538 // Coefficients for the Lanczos approximation of the gamma function. The
539 // coefficients are uniquely determined by the choice of g and n
540 // (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
541 // correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
542 // evaluated and [7, 9] seemed to be the least sensitive to the quality of the
543 // log function. In particular, [7, 9] is the only choice where -1.5e-5 <=
544 // lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
545 static constexpr double kLanczosGamma = 7; // aka g
546 static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
547 static constexpr std::array<double, 8> kLanczosCoefficients = {
548 676.520368121885098567009190444019, -1259.13921672240287047156078755283,
549 771.3234287776530788486528258894, -176.61502916214059906584551354,
550 12.507343278686904814458936853, -0.13857109526572011689554707,
551 9.984369578019570859563e-6, 1.50563273514931155834e-7};
552
553 class LowerLgammaOp : public RewritePattern {
554 public:
555 explicit LowerLgammaOp(MLIRContext *context)
556 : RewritePattern(LgammaOp::getOperationName(), 1, context,
557 {
558 CastOp::getOperationName(),
559 ConstOp::getOperationName(),
560 NegOp::getOperationName(),
561 SubOp::getOperationName(),
562 SelectV2Op::getOperationName(),
563 LessOp::getOperationName(),
564 AddV2Op::getOperationName(),
565 DivOp::getOperationName(),
566 SubOp::getOperationName(),
567 LogOp::getOperationName(),
568 Log1pOp::getOperationName(),
569 IsInfOp::getOperationName(),
570 MulOp::getOperationName(),
571 FloorOp::getOperationName(),
572 AbsOp::getOperationName(),
573 GreaterOp::getOperationName(),
574 SinOp::getOperationName(),
575 IsFiniteOp::getOperationName(),
576 }) {}
577
578 LogicalResult matchAndRewrite(Operation *src_op,
579 PatternRewriter &rewriter) const override {
580 auto op = cast<LgammaOp>(src_op);
581
582 Location loc = op.getLoc();
583 Value input = op.x();
584 TensorType original_tensor_type = op.x().getType().cast<TensorType>();
585
586 // The approximation is not precise enough for float16. Do the computation
587 // in float32 for that case.
588 TensorType tensor_type = original_tensor_type;
589 FloatType float_type = tensor_type.getElementType().cast<FloatType>();
590 bool needs_cast = float_type.getWidth() < 32;
591 if (needs_cast) {
592 MLIRContext *context = rewriter.getContext();
593 float_type = FloatType::getF32(context);
594 if (original_tensor_type.hasRank()) {
595 tensor_type =
596 RankedTensorType::get(original_tensor_type.getShape(), float_type);
597 } else {
598 tensor_type = UnrankedTensorType::get(float_type);
599 }
600 input = rewriter.create<CastOp>(loc, tensor_type, input);
601 }
602
603 // Helper lambda function for creating a ConstOp for a tensor filled with
604 // the given constant float value.
605 auto create_const_op = [&rewriter, loc, tensor_type,
606 float_type](double value) {
607 return rewriter.create<ConstOp>(
608 loc, DenseElementsAttr::get(tensor_type,
609 FloatAttr::get(float_type, value)));
610 };
611
612 Value one_half = create_const_op(0.5);
613 Value one = create_const_op(1.0);
614 Value infinity = create_const_op(std::numeric_limits<double>::infinity());
615 Value pi = create_const_op(M_PI);
616 Value log_pi = create_const_op(std::log(M_PI));
617 Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
618 Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
619 Value log_lanczos_gamma_plus_one_half =
620 create_const_op(std::log(kLanczosGamma + 0.5));
621 Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);
622
623 Value minus_input = rewriter.create<NegOp>(loc, input);
624 Value input_minus_one = rewriter.create<SubOp>(loc, input, one);
625
626 // If the input is less than 0.5 use Euler's reflection formula:
627 // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
628 Value need_to_reflect = rewriter.create<LessOp>(loc, input, one_half);
629 Type tensor_bool_type = need_to_reflect.getType();
630 Value z = rewriter.create<SelectV2Op>(loc, need_to_reflect, minus_input,
631 input_minus_one);
632
633 Value x = base_lanczos_coeff;
634 for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
635 Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
636 Value index = create_const_op(static_cast<double>(i));
637 Value z_plus_index = rewriter.create<AddV2Op>(loc, z, index);
638 Value z_plus_index_plus_one =
639 rewriter.create<AddV2Op>(loc, z_plus_index, one);
640 Value incr = rewriter.create<DivOp>(loc, lanczos_coefficient,
641 z_plus_index_plus_one);
642 x = rewriter.create<AddV2Op>(loc, x, incr);
643 }
644
645 // To improve accuracy on platforms with less-precise log implementations,
646 // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
647 // the device.
648 // log(t) = log(kLanczosGamma + 0.5 + z)
649 // = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
650 Value t = rewriter.create<AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
651 Value z_div_lanczos_gamma_plus_one_half =
652 rewriter.create<DivOp>(loc, z, lanczos_gamma_plus_one_half);
653 Value log1p_z_div_lanczos_gamma_plus_one_half =
654 rewriter.create<Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
655 Value log_t =
656 rewriter.create<AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
657 log1p_z_div_lanczos_gamma_plus_one_half);
658
659 // Compute the final result (modulo reflection). t(z) may be large, and we
660 // need to be careful not to overflow to infinity in the first term of
661 //
662 // (z + 1/2) * log(t(z)) - t(z).
663 //
664 // Therefore we compute this as
665 //
666 // (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
667 //
668 // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
669 Value t_div_log_t = rewriter.create<DivOp>(loc, t, log_t);
670 Value one_half_minus_t_div_log_t =
671 rewriter.create<SubOp>(loc, one_half, t_div_log_t);
672 Value z_plus_one_half_minus_t_div_log_t =
673 rewriter.create<AddV2Op>(loc, z, one_half_minus_t_div_log_t);
674 Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
675 rewriter.create<MulOp>(loc, z_plus_one_half_minus_t_div_log_t, log_t);
676 Value log_x = rewriter.create<LogOp>(loc, x);
677 Value log_y_rhs = rewriter.create<AddV2Op>(
678 loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
679 Value log_y = rewriter.create<AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);
680
681 // Compute the reflected value, used when x < 0.5:
682 //
683 // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
684 //
685 // (The abs is because lgamma is the log of the absolute value of the gamma
686 // function.)
687 //
688 // We have to be careful when computing the final term above. gamma(x) goes
689 // to +/-inf at every integer x < 0, and this is controlled by the
690 // sin(pi * x) term. The slope is large, so precision is particularly
691 // important.
692 //
693 // Because abs(sin(pi * x)) has period 1, we can equivalently use
694 // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x. This
695 // is more numerically accurate: It doesn't overflow to inf like pi * x can,
696 // and if x is an integer, it evaluates to 0 exactly, which is significant
697 // because we then take the log of this value, and log(0) is inf.
698 //
699 // We don't have a frac(x) primitive in XLA and computing it is tricky, but
700 // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
701 // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
702 //
703 // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
704 // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
705 // [0, 1] is symmetric across the line Y=0.5.
706 Value abs_input = rewriter.create<AbsOp>(loc, input);
707 Value abs_input_floor = rewriter.create<FloorOp>(loc, abs_input);
708 Value abs_frac_input =
709 rewriter.create<SubOp>(loc, abs_input, abs_input_floor);
710
711 // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
712 // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
713 Value one_minus_abs_frac_input =
714 rewriter.create<SubOp>(loc, one, abs_frac_input);
715 Value abs_frac_input_gt_one_half =
716 rewriter.create<GreaterOp>(loc, abs_frac_input, one_half);
717 Value reduced_frac_input =
718 rewriter.create<SelectV2Op>(loc, abs_frac_input_gt_one_half,
719 one_minus_abs_frac_input, abs_frac_input);
720 Value pi_mul_reduced_frac_input =
721 rewriter.create<MulOp>(loc, pi, reduced_frac_input);
722 Value sin_pi_mul_reduced_frac_input =
723 rewriter.create<SinOp>(loc, pi_mul_reduced_frac_input);
724 Value reflection_denom =
725 rewriter.create<LogOp>(loc, sin_pi_mul_reduced_frac_input);
726
727 // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
728 // then it "wins" and the result is +/-inf.
729 Value is_finite =
730 rewriter.create<IsFiniteOp>(loc, tensor_bool_type, reflection_denom);
731 Value neg_reflection_denom = rewriter.create<NegOp>(loc, reflection_denom);
732 Value log_pi_minus_reflection_denom =
733 rewriter.create<SubOp>(loc, log_pi, reflection_denom);
734 Value reflection_if_finite =
735 rewriter.create<SubOp>(loc, log_pi_minus_reflection_denom, log_y);
736 Value reflection = rewriter.create<SelectV2Op>(
737 loc, is_finite, reflection_if_finite, neg_reflection_denom);
738
739 Value result =
740 rewriter.create<SelectV2Op>(loc, need_to_reflect, reflection, log_y);
741
742 // lgamma(+/-inf) = +inf.
743 Value is_inf = rewriter.create<IsInfOp>(loc, tensor_bool_type, input);
744 result = rewriter.create<SelectV2Op>(loc, is_inf, infinity, result);
745
746 if (needs_cast) {
747 result = rewriter.create<CastOp>(loc, original_tensor_type, result);
748 }
749
750 rewriter.replaceOp(op, result);
751 return success();
752 }
753 };
754
755 // Lowers Pack op to ConcatV2 op after changing shape of the inputs with
756 // ExpandDims op.
757 //
758 // Sample result with 2 inputs to pack:
759 //
760 // %axis = "tf.Const"() {value = dense<1> : tensor<i64>}
761 // %inp0 = "tf.ExpandDims"(%operand0, %axis): tensor<2xf32> -> tensor<2x1xf32>
762 // %inp1 = "tf.ExpandDims"(%operand1, %axis): tensor<2xf32> -> tensor<2x1xf32>
763 // %result = "tf.ConcatV2"(%inp0, %inp1, %axis) { N = 2 : i64 }:
764 //   (tensor<2x1xf32>, tensor<2x1xf32>, tensor<i64>) -> tensor<2x2xf32>
765 class LowerPackOp : public RewritePattern {
766 public:
767 explicit LowerPackOp(MLIRContext *context)
768 : RewritePattern(
769 PackOp::getOperationName(), 1, context,
770 {ConstOp::getOperationName(), ConcatV2Op::getOperationName(),
771 ExpandDimsOp::getOperationName()}) {}
772
773 LogicalResult matchAndRewrite(Operation *src_op,
774 PatternRewriter &rewriter) const override {
775 auto op = cast<PackOp>(src_op);
776
777 Location loc = op.getLoc();
778 auto axis_value = rewriter.create<ConstOp>(
779 loc,
780 DenseElementsAttr::get(
781 RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis()));
782 int64_t axis = op.axis();
783
784 Type prev_input_ty, inferred_ty;
785 SmallVector<Value, 4> expanded_inputs;
786 expanded_inputs.reserve(op.N());
787 for (Value input : op.values()) {
788 // If input type is different than the previous input type, infer the
789 // output type. Otherwise, use the already inferred output type from the
790 // previous iteration.
791 Type input_ty = input.getType();
792 if (input_ty != prev_input_ty) {
793 inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
794 prev_input_ty = input_ty;
795 }
796 expanded_inputs.push_back(
797 rewriter.create<ExpandDimsOp>(loc, inferred_ty, input, axis_value));
798 }
799
800 rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), expanded_inputs,
801 axis_value);
802 return success();
803 }
804 };
805
806 // Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))).
807 //
808 // Before rewrite:
809 // output = SpaceToBatchND(input, block_shape, paddings)
810 // Let:
811 // [batch] + spatial_shape + remaining_shape = input.shape
812 // M = spatial_shape.rank
813 // After rewrite:
814 // padded = zero-pad input with paddings
815 // The spatial_shape component of input.shape pads with paddings[*, 0]
816 // before each dimension, and paddings[*, 1] after each dimension.
817 // reshaped = reshape padded to:
818 // [batch]
819 // + [padded.shape[1]/block_shape[0], block_shape[0], ...,
820 // padded.shape[M]/block_shape[M-1], block_shape[M-1]]
821 // + remaining_shape
822 // permuted = transpose reshaped to:
823 // block_shape
824 // + [batch]
825 // + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
826 // + remaining_shape
827 // result = reshape permuted to:
828 // [batch * product(block_shape)]
829 // + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
830 // + remaining_shape
831 class LowerSpaceToBatchNDOp : public RewritePattern {
832 public:
833 explicit LowerSpaceToBatchNDOp(MLIRContext *context)
834 : RewritePattern(SpaceToBatchNDOp::getOperationName(), 1, context,
835 {
836 CastOp::getOperationName(),
837 ConstOp::getOperationName(),
838 ConcatV2Op::getOperationName(),
839 AddV2Op::getOperationName(),
840 PadOp::getOperationName(),
841 SplitOp::getOperationName(),
842 UnpackOp::getOperationName(),
843 DivOp::getOperationName(),
844 MulOp::getOperationName(),
845 ReshapeOp::getOperationName(),
846 TransposeOp::getOperationName(),
847 }) {}
848
849 LogicalResult matchAndRewrite(Operation *src_op,
850 PatternRewriter &rewriter) const override {
851 auto op = cast<SpaceToBatchNDOp>(src_op);
852
853 Location loc = op.getLoc();
854 auto input_type = op.input().getType().cast<TensorType>();
855 auto element_type = input_type.getElementType();
856 if (!input_type.hasStaticShape()) {
857 return failure();
858 }
859 ArrayRef<int64_t> input_shape = input_type.getShape();
860 auto block_shape_type = op.block_shape().getType().cast<TensorType>();
861 if (!block_shape_type.hasStaticShape()) {
862 return failure();
863 }
864 auto paddings_type = op.paddings().getType().cast<ShapedType>();
865 if (!paddings_type.hasRank()) {
866 return failure();
867 }
868
869 int64_t input_rank = input_type.getRank();
870 int64_t block_rank = block_shape_type.getNumElements();
871 int64_t remaining_rank = input_rank - 1 - block_rank;
872 if (remaining_rank < 0) {
873 // TODO(b/157475606): Move this check to ::Verify
874 return failure();
875 }
876
877 auto block_shape_i64_type = RankedTensorType::get(
878 block_shape_type.getShape(), rewriter.getIntegerType(64));
879 auto block_shape_i64 =
880 rewriter.create<CastOp>(loc, block_shape_i64_type, op.block_shape());
881
882 auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(),
883 rewriter.getIntegerType(64));
884 auto paddings_i64 =
885 rewriter.create<CastOp>(loc, paddings_i64_type, op.paddings());
886
887 auto pad00 = rewriter.create<ConstOp>(
888 loc, DenseElementsAttr::get<int64_t>(
889 RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)),
890 {0, 0}));
891 SmallVector<Value, 4> full_paddings_list{pad00, paddings_i64};
892 full_paddings_list.append(remaining_rank, pad00);
893 auto full_paddings_type =
894 RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64));
895 auto zero_i64 = rewriter.create<ConstOp>(
896 loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
897 // Extends paddings to all dimensions of input by adding 0s to non-block
898 // dimensions.
899 auto full_paddings = rewriter.create<ConcatV2Op>(
900 loc, full_paddings_type, full_paddings_list, zero_i64);
901
902 // Compute the result type here instead of using shape inference because the
903 // full_paddings won't be available as a constant for shape inference.
904 ElementsAttr block_shape;
905 ElementsAttr paddings;
906 llvm::SmallVector<int64_t, 4> block_shape_ints;
907 auto padded_shape = llvm::to_vector<4>(input_shape);
908 if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
909 matchPattern(op.paddings(), m_Constant(&paddings))) {
910 for (uint64_t i = 0; i < block_rank; i++) {
911 int64_t paddings_sum =
912 paddings.getValues<APInt>()[{i, 0}].getSExtValue() +
913 paddings.getValues<APInt>()[{i, 1}].getSExtValue();
914 int64_t block_shape_i =
915 block_shape.getValues<APInt>()[i].getSExtValue();
916 padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
917 block_shape_ints.push_back(block_shape_i);
918 }
919 } else {
920 for (int i = 0; i < block_rank; i++) {
921 padded_shape[i + 1] = ShapedType::kDynamicSize;
922 }
923 block_shape_ints.resize(block_shape_type.getNumElements(), -1);
924 }
925
926 auto padded_type = RankedTensorType::get(padded_shape, element_type);
927 // padded = pad(input, full_paddings)
928 auto padded =
929 rewriter.create<PadOp>(loc, padded_type, op.input(), full_paddings);
930
931 auto paddings_sum_type =
932 RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
933 // paddings_sum = paddings[*,0] + paddings[*,1]
934 auto paddings_split = rewriter.create<UnpackOp>(
935 loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
936 rewriter.getI64IntegerAttr(1));
937 auto paddings_sum = rewriter.create<AddV2Op>(
938 loc, paddings_split.getResult(0), paddings_split.getResult(1));
939
940 auto input_shape_tensor = rewriter.create<ConstOp>(
941 loc,
942 DenseElementsAttr::get(
943 RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)),
944 input_shape));
945
946 // padded_shape_tensor is the shape of padded.
947 auto padded_shape_tensor =
948 rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
949
950 auto zero_i32 = rewriter.create<ConstOp>(
951 loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
952 SmallVector<Type, 4> padded_shape_splits_types(
953 input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
954 SmallVector<Value, 4> padded_shape_splits(
955 rewriter
956 .create<SplitOp>(loc, padded_shape_splits_types, zero_i32,
957 padded_shape_tensor)
958 .output());
959
960 SmallVector<Type, 4> block_shape_splits_types(
961 block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
962 SmallVector<Value, 4> block_shape_splits(
963 rewriter
964 .create<SplitOp>(loc, block_shape_splits_types, zero_i32,
965 block_shape_i64)
966 .output());
967
968 SmallVector<int64_t, 4> outer_shape_ints;
969 SmallVector<Value, 4> outer_shape_vals;
970 for (int64_t i = 0; i < block_rank; ++i) {
971 // TODO(b/157475606): Insert tf.Assert that the following division has
972 // remainder 0.
973 outer_shape_vals.push_back(rewriter.create<DivOp>(
974 loc, padded_shape_splits[1 + i], block_shape_splits[i]));
975
976 auto padded_shape_i = padded_shape[1 + i];
977 auto block_shape_ints_i = block_shape_ints[i];
978
979 // Compute the outer_shape constant values to infer the reshape.
980 if (padded_shape_i == -1 || block_shape_ints_i == -1) {
981 outer_shape_ints.push_back(-1);
982 } else {
983 outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
984 }
985 }
986
987 SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
988 SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
989 for (int64_t i = 0; i < block_rank; ++i) {
990 reshaped_shape_vals.push_back(outer_shape_vals[i]);
991 reshaped_shape_vals.push_back(block_shape_splits[i]);
992
993 reshaped_shape_ints.push_back(outer_shape_ints[i]);
994 reshaped_shape_ints.push_back(block_shape_ints[i]);
995 }
996 for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
997 reshaped_shape_vals.push_back(padded_shape_splits[i]);
998 reshaped_shape_ints.push_back(padded_shape[i]);
999 }
1000 auto reshaped_shape = ValuesToRank1(
1001 rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);
1002
1003 auto reshaped = rewriter.create<ReshapeOp>(
1004 loc, RankedTensorType::get(reshaped_shape_ints, element_type), padded,
1005 reshaped_shape);
1006
1007 SmallVector<int64_t, 6> permutation_vals;
1008 for (int64_t i = 0; i < block_rank; ++i) {
1009 permutation_vals.push_back(2 + 2 * i);
1010 }
1011 permutation_vals.push_back(0);
1012 for (int64_t i = 0; i < block_rank; ++i) {
1013 permutation_vals.push_back(1 + 2 * i);
1014 }
1015 for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1016 permutation_vals.push_back(block_rank + i);
1017 }
1018 auto permutation = rewriter.create<ConstOp>(
1019 loc, GetI64ElementsAttr(permutation_vals, &rewriter));
1020
1021 auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
1022 auto output_batch = padded_shape_splits[0];
1023 for (int64_t i = 0; i < block_rank; ++i) {
1024 output_batch =
1025 rewriter.create<MulOp>(loc, output_batch, block_shape_splits[i]);
1026 }
1027 SmallVector<Value, 4> output_shape_vals{output_batch};
1028 for (int64_t i = 0; i < block_rank; ++i) {
1029 output_shape_vals.push_back(outer_shape_vals[i]);
1030 }
1031 for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1032 output_shape_vals.push_back(padded_shape_splits[i]);
1033 }
1034 auto output_shape = ValuesToRank1(
1035 rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
1036
1037 // Sometimes the result type is more specific than what the reshape builder
1038 // can infer.
1039 auto result_type = op.getResult().getType();
1040 rewriter.replaceOpWithNewOp<ReshapeOp>(op, result_type, permuted,
1041 output_shape);
1042
1043 return success();
1044 }
1045 };
1046
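// Lowers BatchToSpaceND by reducing to slice(reshape(transpose(reshape(input)))),
// i.e. the inverse of the SpaceToBatchND lowering above. As an illustrative
// example, with block_shape = [2, 2] and zero crops, an input of shape
// [4, 1, 1, 1] is reshaped to [2, 2, 1, 1, 1, 1], transposed to
// [1, 1, 2, 1, 2, 1], reshaped to [1, 2, 2, 1], and finally sliced (a no-op
// here) to produce the [1, 2, 2, 1] result.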
1047 class LowerBatchToSpaceND : public RewritePattern {
1048 public:
1049 explicit LowerBatchToSpaceND(MLIRContext *context)
1050 : RewritePattern(BatchToSpaceNDOp::getOperationName(), 1, context,
1051 {
1052 ConstOp::getOperationName(),
1053 ReshapeOp::getOperationName(),
1054 SliceOp::getOperationName(),
1055 TransposeOp::getOperationName(),
1056 }) {}
1057
1058 LogicalResult matchAndRewrite(Operation *src_op,
1059 PatternRewriter &rewriter) const override {
1060 auto op = cast<BatchToSpaceNDOp>(src_op);
1061 auto input = op.input();
1062 auto input_ty = input.getType().cast<ShapedType>();
1063 auto element_ty = input_ty.getElementType();
1064 if (!input_ty.hasStaticShape()) {
1065 return failure();
1066 }
1067
1068 const int input_rank = input_ty.getRank();
1069 auto input_shape = input_ty.getShape();
1070
1071 DenseIntElementsAttr block_shape;
1072 DenseIntElementsAttr crops;
1073 if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
1074 !matchPattern(op.crops(), m_Constant(&crops))) {
1075 return failure();
1076 }
1077
1078 auto block_shape_ty = block_shape.getType();
1079 if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
1080 return failure();
1081 }
1082
1083 const int block_rank = block_shape_ty.getShape().front();
1084 auto remainder_shape = input_shape.drop_front(1 + block_rank);
1085
1086 const int64_t batch_size = input_shape[0];
1087
1088 // Compute the product of the block_shape values.
1089 int64_t block_num_elems = 1;
1090
1091 for (auto val : block_shape.getValues<APInt>()) {
1092 block_num_elems *= val.getSExtValue();
1093 }
1094
1095 if (block_num_elems <= 0) {
1096 op.emitOpError()
1097 << "The product of the block dimensions must be positive";
1098 return failure();
1099 }
1100
1101 // 1. Reshape `input` to `reshaped` of shape:
1102 // [block_shape[0], ..., block_shape[M-1],
1103 // batch / prod(block_shape),
1104 // input_shape[1], ..., input_shape[N-1]]
1105 SmallVector<int64_t> reshaped_shape;
1106 reshaped_shape.reserve(block_shape.size());
1107 for (auto val : block_shape) {
1108 reshaped_shape.push_back(val.getSExtValue());
1109 }
1110 reshaped_shape.resize(input_rank + block_rank);
1111
1112 reshaped_shape[block_rank] = batch_size / block_num_elems;
1113 std::copy(input_shape.begin() + 1, input_shape.end(),
1114 reshaped_shape.begin() + block_rank + 1);
1115
1116 auto reshaped = rewriter.create<TF::ReshapeOp>(
1117 op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
1118 rewriter.create<ConstOp>(op.getLoc(),
1119 rewriter.getI64TensorAttr(reshaped_shape)));
1120
1121 // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
1122 // [batch / prod(block_shape),
1123 //
1124 // input_shape[1], block_shape[0],
1125 // ...,
1126 // input_shape[M], block_shape[M-1],
1127 //
1128 // input_shape[M+1], ..., input_shape[N-1]]
1129 SmallVector<int64_t> permutation(reshaped_shape.size());
1130 permutation[0] = block_rank;
1131 for (int i = 0; i < block_rank; ++i) {
1132 permutation[1 + 2 * i] = block_rank + 1 + i;
1133 permutation[1 + 2 * i + 1] = i;
1134 }
1135 std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1136 1 + block_rank * 2);
1137
1138 SmallVector<int64_t> transpose_shape(permutation.size());
1139 for (auto it : llvm::enumerate(permutation)) {
1140 transpose_shape[it.index()] = reshaped_shape[it.value()];
1141 }
1142
1143 auto permuted = rewriter.create<TF::TransposeOp>(
1144 op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
1145 reshaped,
1146 rewriter.create<ConstOp>(op.getLoc(),
1147 rewriter.getI64TensorAttr(permutation)));
1148
1149 // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
1150 // [batch / prod(block_shape),
1151 //
1152 // input_shape[1] * block_shape[0],
1153 // ...,
1154 // input_shape[M] * block_shape[M-1],
1155 //
1156 // input_shape[M+1],
1157 // ...,
1158 // input_shape[N-1]]
1159 SmallVector<int64_t> reshaped_permuted_shape(input_rank);
1160 auto block_shape_values =
1161 llvm::to_vector<4>(block_shape.getValues<APInt>());
1162 reshaped_permuted_shape[0] = batch_size / block_num_elems;
1163 for (int i = 0; i < block_rank; ++i) {
1164 reshaped_permuted_shape[1 + i] =
1165 block_shape_values[i].getSExtValue() * input_shape[1 + i];
1166 }
1167 std::copy(remainder_shape.begin(), remainder_shape.end(),
1168 reshaped_permuted_shape.begin() + 1 + block_rank);
1169
1170 auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
1171 op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
1172 permuted,
1173 rewriter.create<ConstOp>(
1174 op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));
1175
1176 // 4. Crop the start and end of dimensions `[1, ..., M]` of
1177 // `reshaped_permuted` according to `crops` to produce the output of
1178 // shape:
1179 // [batch / prod(block_shape),
1180 //
1181 // input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
1182 // ...,
1183 // input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
1184 //
1185 // input_shape[M+1], ..., input_shape[N-1]]
1186 SmallVector<int64_t> start_indices(input_rank, 0);
1187 SmallVector<int64_t> slice_sizes = reshaped_permuted_shape;
1188 SmallVector<int64_t> strides(input_rank, 1);
1189 auto crop_values = llvm::to_vector<4>(crops.getValues<APInt>());
1190 for (int i = 0; i < block_rank; ++i) {
1191 int64_t crop_start = crop_values[i * 2].getSExtValue();
1192 int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();
1193
1194 if (crop_start < 0 || crop_end < 0) {
1195 op.emitOpError() << "Crops must be non-negative";
1196 return failure();
1197 }
1198
1199 start_indices[i + 1] = crop_start;
1200 slice_sizes[i + 1] -= crop_start + crop_end;
1201
1202 if (slice_sizes[i + 1] < 0) {
1203 op.emitOpError() << "Cropped size must be non-negative: start: "
1204 << crop_start << " end: " << crop_end << " size "
1205 << reshaped_permuted_shape[1 + i];
1206 }
1207 }
1208
1209 rewriter.replaceOpWithNewOp<TF::SliceOp>(
1210 op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
1211 rewriter.create<ConstOp>(op.getLoc(),
1212 rewriter.getI64TensorAttr(start_indices)),
1213 rewriter.create<ConstOp>(op.getLoc(),
1214 rewriter.getI64TensorAttr(slice_sizes)));
1215 return success();
1216 }
1217 };
1218
1219 // Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
1220 // since we currently don't have an implementation that can use this
1221 // information. Adds appropriate casts where necessary to align element types
1222 // of operands and result for `MatMulOp`.
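// For example (illustrative), with a bf16 LHS and an f32 RHS:
//
//   %a_f32 = "tf.Cast"(%a) : (tensor<2x3xbf16>) -> tensor<2x3xf32>
//   %product = "tf.MatMul"(%a_f32, %b)
//       : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>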
1223 class LowerSparseMatMulOp : public RewritePattern {
1224 public:
1225 explicit LowerSparseMatMulOp(MLIRContext *context)
1226 : RewritePattern(
1227 SparseMatMulOp::getOperationName(), 1, context,
1228 {CastOp::getOperationName(), MatMulOp::getOperationName()}) {}
1229
1230 LogicalResult matchAndRewrite(Operation *src_op,
1231 PatternRewriter &rewriter) const override {
1232 auto op = cast<SparseMatMulOp>(src_op);
1233
1234 // Result type must be f32 for applying the pattern (currently this is
1235 // required by the op anyway but this might change).
1236 if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
1237 return failure();
1238 }
1239 MLIRContext *context = rewriter.getContext();
1240 llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
1241 for (Value &operand : operands) {
1242 TensorType tensor_type = operand.getType().cast<TensorType>();
1243 Type element_type = tensor_type.getElementType();
1244 if (element_type.isF32()) continue;
1245 // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
1246 // must be bf16 here.
1247 assert(element_type.isBF16());
1248 Type tensor_type_f32;
1249 if (tensor_type.hasRank()) {
1250 tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
1251 FloatType::getF32(context));
1252 } else {
1253 tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
1254 }
1255 // Add cast to f32 to conform with element type of result.
1256 operand = rewriter.create<CastOp>(op.getLoc(), tensor_type_f32, operand);
1257 }
1258 Value result = rewriter.create<MatMulOp>(
1259 op.getLoc(), op.product().getType(), operands[0], operands[1],
1260 op.transpose_a(), op.transpose_b());
1261
1262 rewriter.replaceOp(op, {result});
1263 return success();
1264 }
1265 };
1266
1267 // Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that
1268 // were fused together.
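// For example (illustrative), op_names = ["Abs", "Sqrt"] expands to a tf.Abs
// followed by a tf.Sqrt applied to its result, each producing the original
// result type.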
1269 class Lower_UnaryOpsComposition
1270 : public OpRewritePattern<_UnaryOpsCompositionOp> {
1271 public:
1272 using OpRewritePattern<_UnaryOpsCompositionOp>::OpRewritePattern;
1273
1274 LogicalResult matchAndRewrite(_UnaryOpsCompositionOp op,
1275 PatternRewriter &rewriter) const override {
1276 Value result = op.x();
1277 for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
1278 std::string full_name = "tf." + op_name.str();
1279 // All ops in the sequences have the same result type as the original
1280 // result type.
1281 OperationState state(op.getLoc(), full_name, /*operands=*/{result},
1282 /*types=*/{op.getType()}, /*attributes=*/{});
1283 Operation *op = rewriter.create(state);
1284 result = op->getResult(0);
1285 }
1286 rewriter.replaceOp(op, {result});
1287 return success();
1288 }
1289 };
1290
1291 // Lowers ResizeNearestNeighbor to an indices computation with a gather along
1292 // the combined spatial dimensions. Indices generated along the width/height
1293 // dimensions could be used to gather along each of the W and H dimensions of
1294 // the input image array. To reduce to a single gather, these indices are
1295 // combined, so a single gather can be performed along the combined spatial
1296 // dimensions.
1297 //
1298 // Images must take the shape [b, h, w, c] and size is a rank-1 length-2 tensor
1299 // containing the height and width values for the output tensor. This lowering
1300 // should work with a dynamic images array.
1301 //
1302 // For example, a scaling with image shape [1, 3, 3, 1] to [2, 2] and unaligned
1303 // corners would generate a [0, 1] lookup along both the x and y direction.
1304 // Then when combined to form the 1-D spatial index the values would be
1305 // [0, 1, 3, 4] which would gather along the reshaped image tensor of shape
1306 // [1, 9, 1], reshaped to the final [1, 2, 2, 1].
1307 class LowerResizeNearestNeighbor : public RewritePattern {
1308 public:
1309 explicit LowerResizeNearestNeighbor(MLIRContext *context)
1310 : RewritePattern(ResizeNearestNeighborOp::getOperationName(), 1, context,
1311 {
1312 BroadcastToOp::getOperationName(),
1313 ConstOp::getOperationName(),
1314 DivOp::getOperationName(),
1315 PackOp::getOperationName(),
1316 RangeOp::getOperationName(),
1317 ReshapeOp::getOperationName(),
1318 ShapeOp::getOperationName(),
1319 SplitOp::getOperationName(),
1320 TransposeOp::getOperationName(),
1321 }) {}
1322
1323 LogicalResult matchAndRewrite(Operation *src_op,
1324 PatternRewriter &rewriter) const override {
1325 auto op = cast<ResizeNearestNeighborOp>(src_op);
1326 auto loc = op.getLoc();
1327 auto result_ty = op.getType().cast<ShapedType>();
1328
1329 auto input = op.images();
1330 auto input_ty = input.getType().cast<ShapedType>();
1331 auto input_element_ty = input_ty.getElementType();
1332 auto out_size = op.size();
1333 auto out_size_ty = out_size.getType().cast<ShapedType>();
1334 auto out_size_element_ty = out_size_ty.getElementType();
1335
1336 // Input should be rank 4.
1337 if (!input_ty.hasRank() || input_ty.getRank() != 4) {
1338 return failure();
1339 }
1340
1341 // Check that out_size is rank-1, length-2. Otherwise the size is not legal.
1342 if (!out_size_ty.hasRank() || out_size_ty.getRank() != 1 ||
1343 out_size_ty.getShape()[0] != 2) {
1344 return failure();
1345 }
1346
1347 // Extract the output width / height dim size.
1348 int out_height_constant = -1;
1349 int out_width_constant = -1;
1350 DenseIntElementsAttr out_size_cst;
1351 if (matchPattern(out_size, m_Constant(&out_size_cst))) {
1352 llvm::SmallVector<int64_t, 2> cst_size;
1353 for (auto val : out_size_cst.getValues<APInt>()) {
1354 cst_size.push_back(val.getSExtValue());
1355 }
1356
1357 out_height_constant = cst_size[0];
1358 out_width_constant = cst_size[1];
1359
1360 if (out_height_constant < 0 || out_width_constant < 0) return failure();
1361 }
1362
1363 int out_spatial_cst = out_height_constant < 0 || out_width_constant < 0
1364 ? -1
1365 : out_height_constant * out_width_constant;
1366
1373 int batch_cst = input_ty.getShape()[0];
1374 int channels_cst = input_ty.getShape()[3];
1375
1376 int in_y_cst = input_ty.getShape()[1];
1377 int in_x_cst = input_ty.getShape()[2];
1378 int in_spatial_cst =
1379 in_y_cst < 0 || in_x_cst < 0 ? -1 : in_y_cst * in_x_cst;
1380
1381 // TODO(suderman): Add support for these optional parameters.
1382 if (op.align_corners() || op.half_pixel_centers()) {
1383 return failure();
1384 }
1385
1386 auto one =
1387 rewriter.create<ConstOp>(loc, GetScalarOfType(out_size_element_ty, 1));
1388
1389 // Extract the image shape.
1390 Value input_shape = rewriter.create<ShapeOp>(
1391 loc, RankedTensorType::get({4}, rewriter.getI64Type()), input);
1392 input_shape = rewriter.create<CastOp>(
1393 loc, RankedTensorType::get({4}, out_size_element_ty), input_shape);
1394
1395 auto scalar_dim_ty = RankedTensorType::get({}, out_size_element_ty);
1396 auto split_image_shape = rewriter.create<UnpackOp>(
1397 loc,
1398 TypeRange({scalar_dim_ty, scalar_dim_ty, scalar_dim_ty, scalar_dim_ty}),
1399 input_shape);
1400
1401 // Extract the separate components from the input shape.
1402 auto batch = split_image_shape.getResult(0);
1403 auto in_y = split_image_shape.getResult(1);
1404 auto in_x = split_image_shape.getResult(2);
1405 auto channels = split_image_shape.getResult(3);
1406
1407 auto in_count = rewriter.create<MulOp>(
1408 loc, RankedTensorType::get({}, out_size_element_ty), in_y, in_x);
1409
1410 // Unpack and separate the out width/height.
1411 auto split_out_size = rewriter.create<UnpackOp>(
1412 loc, TypeRange({scalar_dim_ty, scalar_dim_ty}), out_size);
1413
1414 auto out_y = split_out_size.getResult(0);
1415 auto out_x = split_out_size.getResult(1);
1416
1417 auto out_count = rewriter.create<MulOp>(
1418 loc, RankedTensorType::get({}, out_size_element_ty), out_y, out_x);
1419
1420 // Generate what the final output shape will look like.
1421 auto out_shape = rewriter.create<PackOp>(
1422 loc, RankedTensorType::get({4}, out_size_element_ty),
1423 ValueRange({batch, out_y, out_x, channels}));
1424
1425 // Compute the indices along the vertical dimension.
1426 auto in_y_f32 = rewriter.create<CastOp>(
1427 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y);
1428 auto out_y_f32 = rewriter.create<CastOp>(
1429 loc, RankedTensorType::get({}, rewriter.getF32Type()), out_y);
1430
1431 Value y_scale = rewriter.create<DivOp>(
1432 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y_f32,
1433 out_y_f32);
1434
1435 Value zero_f32 = rewriter.create<ConstOp>(
1436 loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
1437 Value one_f32 = rewriter.create<ConstOp>(
1438 loc, GetScalarOfType(rewriter.getF32Type(), 1.0));
1439
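// Build the row lookup: range(0, out_h) scaled by in_h / out_h, truncated to
// integers, and reshaped to an [out_h, 1] column so it broadcasts against the
// column lookup computed below.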
1440 Value y_range = rewriter.create<RangeOp>(
1441 loc,
1442 RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1443 zero_f32, out_y_f32, one_f32);
1444
1445 y_range = rewriter.create<MulOp>(
1446 loc,
1447 RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1448 y_range, y_scale);
1449
1450 y_range = rewriter.create<CastOp>(
1451 loc, RankedTensorType::get({out_height_constant}, out_size_element_ty),
1452 y_range);
1453
1454 y_range = rewriter.create<ReshapeOp>(
1455 loc,
1456 RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1457 y_range,
1458 rewriter.create<PackOp>(loc,
1459 RankedTensorType::get({2}, out_size_element_ty),
1460 ValueRange({out_y, one})));
1461
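// Scale the row indices by the input width so that, once added to the column
// indices, they address the flattened [h * w] spatial dimension.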
1462 Value y_indices = rewriter.create<MulOp>(
1463 loc,
1464 RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1465 y_range, in_x);
1466
1467 // Compute the indices for the nearest neighbor lookup across the width
1468 // dim.
1469 auto in_x_f32 = rewriter.create<CastOp>(
1470 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x);
1471 auto out_x_f32 = rewriter.create<CastOp>(
1472 loc, RankedTensorType::get({}, rewriter.getF32Type()), out_x);
1473
1474 Value x_scale = rewriter.create<DivOp>(
1475 loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x_f32,
1476 out_x_f32);
1477
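// Build the column lookup the same way: range(0, out_w) scaled by
// in_w / out_w, truncated to integers, and reshaped to a [1, out_w] row.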
1478 Value x_range = rewriter.create<RangeOp>(
1479 loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1480 zero_f32, out_x_f32, one_f32);
1481
1482 x_range = rewriter.create<MulOp>(
1483 loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1484 x_range, x_scale);
1485
1486 x_range = rewriter.create<CastOp>(
1487 loc, RankedTensorType::get({out_width_constant}, out_size_element_ty),
1488 x_range);
1489
1490 Value x_indices = rewriter.create<ReshapeOp>(
1491 loc,
1492 RankedTensorType::get({1, out_width_constant}, out_size_element_ty),
1493 x_range,
1494 rewriter.create<PackOp>(loc,
1495 RankedTensorType::get({2}, out_size_element_ty),
1496 ValueRange({one, out_x})));
1497
1498 // Generate the combined index array, reshape to be 1-D.
1499 Value indices = rewriter.create<AddV2Op>(
1500 loc,
1501 RankedTensorType::get({out_height_constant, out_width_constant},
1502 out_size_element_ty),
1503 y_indices, x_indices);
1504
1505 indices = rewriter.create<ReshapeOp>(
1506 loc, RankedTensorType::get({out_spatial_cst}, out_size_element_ty),
1507 indices,
1508 rewriter.create<ReshapeOp>(
1509 loc, RankedTensorType::get({1}, out_size_element_ty), out_count,
1510 rewriter.create<ConstOp>(loc, rewriter.getI64TensorAttr({1}))));
1511
1512 // Group the spatial indices and gather along that combined index.
1513 Value input_collapsed_spatial = rewriter.create<ReshapeOp>(
1514 loc,
1515 RankedTensorType::get({batch_cst, in_spatial_cst, channels_cst},
1516 input_element_ty),
1517 input,
1518 rewriter.create<PackOp>(loc,
1519 RankedTensorType::get({3}, out_size_element_ty),
1520 ValueRange({batch, in_count, channels})));
1521
1522 Value gathered_values = rewriter.create<GatherV2Op>(
1523 loc,
1524 RankedTensorType::get({batch_cst, out_spatial_cst, channels_cst},
1525 input_element_ty),
1526 input_collapsed_spatial, indices, /*axis=*/one);
1527
1528 gathered_values =
1529 rewriter.create<ReshapeOp>(loc, result_ty, gathered_values, out_shape);
1530
1531 rewriter.replaceOp(op, gathered_values);
1532 return success();
1533 }
1534 };
1535
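// Lowers the Roll op by decomposing the roll along each axis into two Slice
// ops and a ConcatV2 op: the last `shift` elements along the axis are sliced
// off and concatenated in front of the rest. For example, rolling a tensor
// [x0, x1, x2, x3] by shift 1 along axis 0 becomes
//   concat(slice(x, begin=[3], size=[1]), slice(x, begin=[0], size=[3]))
// which yields [x3, x0, x1, x2]. Requires a statically shaped input and
// constant shift and axis operands.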
1536 struct LowerRollOp : public RewritePattern {
1537 explicit LowerRollOp(MLIRContext *context)
1538 : RewritePattern(
1539 RollOp::getOperationName(), 1, context,
1540 {ConstOp::getOperationName(), SliceOp::getOperationName(),
1541 ConcatV2Op::getOperationName()}) {}
1542
1543 LogicalResult matchAndRewrite(Operation *op,
1544 PatternRewriter &rewriter) const override {
1545 auto tf_roll_op = cast<RollOp>(op);
1546
1547 auto input_ty = tf_roll_op.input().getType().dyn_cast<RankedTensorType>();
1548 if (!input_ty || !input_ty.hasStaticShape()) {
1549 return rewriter.notifyMatchFailure(
1550 op, "require the type of input to have static shapes");
1551 }
1552
1553 DenseIntElementsAttr shift_attr;
1554 Value shift = tf_roll_op.shift();
1555 auto shift_ranked_attr_type = shift.getType().dyn_cast<RankedTensorType>();
1556 if (!shift_ranked_attr_type ||
1557 !matchPattern(shift, m_Constant(&shift_attr))) {
1558 return failure();
1559 }
1560
1561 DenseIntElementsAttr axis_attr;
1562 Value axis = tf_roll_op.axis();
1563 auto axis_ranked_attr_type = axis.getType().dyn_cast<RankedTensorType>();
1564 if (!axis_ranked_attr_type || !matchPattern(axis, m_Constant(&axis_attr))) {
1565 return failure();
1566 }
1567
1568 // Combine duplicate axes and make sure they are in [0, rank(input)) range.
1569 auto input_shape = input_ty.getShape();
1570 int input_rank = input_shape.size();
1571 SmallVector<int32_t, 4> shift_map(input_rank, 0);
1572 for (int i = 0; i < axis_attr.getNumElements(); ++i) {
1573 int32_t axis_i = axis_attr.getValues<int32_t>()[i];
1574 if (axis_i < 0) axis_i += input_rank;
1575 int32_t shift_i = shift_attr.getValues<int32_t>()[i];
1576 shift_map[axis_i] += shift_i;
1577 }
1578
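// Normalize each accumulated shift into [0, dim_size) and drop axes whose
// effective shift is zero, since rolling by zero is a no-op.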
1579 SmallVector<int32_t, 4> adjusted_axis;
1580 SmallVector<int32_t, 4> adjusted_shift;
1581 for (int i = 0; i < input_rank; ++i) {
1582 int32_t input_dims_i = input_shape[i];
1583 int32_t shift_i = shift_map[i] % input_dims_i;
1584 if (shift_i < 0) shift_i += input_dims_i;
1585 if (shift_i == 0) continue;
1586 adjusted_axis.push_back(i);
1587 adjusted_shift.push_back(shift_i);
1588 }
1589
1590 // Convert rolling in each dimension to two Slice ops and one Concat op.
1591 auto axis_type =
1592 RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
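// Helper that slices `size_i` elements starting at `begin_i` along `axis_i`,
// keeping all other dimensions whole.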
1593 auto create_slice_op = [&](int32_t axis_i, int32_t begin_i, int32_t size_i,
1594 Value input) {
1595 SmallVector<int64_t, 4> begin_values(input_rank, 0);
1596 begin_values[axis_i] = begin_i;
1597 auto begin_attr = DenseIntElementsAttr::get(axis_type, begin_values);
1598 auto begin =
1599 rewriter.create<ConstOp>(op->getLoc(), axis_type, begin_attr);
1600
1601 SmallVector<int64_t, 4> output_shape;
1602 output_shape.append(input_shape.begin(), input_shape.end());
1603 output_shape[axis_i] = size_i;
1604 auto size_attr = DenseIntElementsAttr::get(axis_type, output_shape);
1605 auto size = rewriter.create<ConstOp>(op->getLoc(), axis_type, size_attr);
1606
1607 auto slice_op_ty =
1608 RankedTensorType::get(output_shape, input_ty.getElementType());
1609 return rewriter.create<SliceOp>(op->getLoc(), slice_op_ty, input, begin,
1610 size);
1611 };
1612
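// Apply the roll axis by axis: slice off the last shift_i elements along the
// axis and concatenate them in front of the remaining elements.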
1613 auto result = tf_roll_op.input();
1614 auto scalar_type =
1615 mlir::RankedTensorType::get({}, rewriter.getIntegerType(32));
1616 for (int i = 0; i < adjusted_axis.size(); ++i) {
1617 int32_t axis_i = adjusted_axis[i];
1618 int32_t shift_i = adjusted_shift[i];
1619 auto slice_op_1 = create_slice_op(axis_i, input_shape[axis_i] - shift_i,
1620 shift_i, result);
1621 auto slice_op_2 =
1622 create_slice_op(axis_i, 0, input_shape[axis_i] - shift_i, result);
1623
1624 auto dim_attr = DenseIntElementsAttr::get(scalar_type, {axis_i});
1625 auto concat_dim =
1626 rewriter.create<ConstOp>(op->getLoc(), scalar_type, dim_attr);
1627 auto concat_op = rewriter.create<ConcatV2Op>(
1628 op->getLoc(), input_ty,
1629 ArrayRef<Value>({slice_op_1.output(), slice_op_2.output()}),
1630 concat_dim);
1631 result = concat_op.getResult();
1632 }
1633
1634 rewriter.replaceOp(op, result);
1635 return success();
1636 }
1637 };
1638
1639 // Decomposes Softmax and LogSoftmax to primitive TF ops, using the following
1640 // formulas:
1641 //
1642 // softmax = div(exp(logits), sum(exp(logits)))
1643 // log_softmax = sub(logits, log(sum(exp(logits))))
1644 //
1645 // TODO(jpienaar): Evaluate benefit of templating here.
1646 template <typename OpTy, bool use_log = true>
1647 class LowerSoftmaxOp : public OpRewritePattern<OpTy> {
1648 public:
1649 using OpRewritePattern<OpTy>::OpRewritePattern;
1650
1651 LogicalResult matchAndRewrite(OpTy op,
1652 PatternRewriter &rewriter) const override {
1653 Value logits = op.logits();
1654 auto loc = op.getLoc();
1655
1656 // Note that the TensorFlow Softmax op verifies that the input rank is
1657 // greater than or equal to one so the following sequence is valid.
1658 auto reduce_dim =
1659 rewriter.create<TF::ConstOp>(loc, GetI64ElementsAttr({-1}, &rewriter));
1660
1661 // The exponentials of the input values, and hence their sum, can be very
1662 // large here, and division by a large denominator is numerically unstable.
1663 // To improve numerical stability, subtract from each batch its max element
1664 // so that the maximum input value is zero. Shifting all inputs in a batch by
1665 // a common value leaves the softmax result mathematically unchanged, so the
1666 // computed value is not affected.
1667 auto max_logits =
1668 rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
1669 /*keep_dims=*/rewriter.getBoolAttr(true));
1670 auto shifted_logits = rewriter.create<TF::SubOp>(loc, logits, max_logits);
1671
1672 // Exponentiate the inputs.
1673 Value exp = rewriter.create<TF::ExpOp>(loc, shifted_logits);
1674
1675 // Compute summation of the exponentials.
1676 Value sum =
1677 rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
1678 /*keep_dims=*/rewriter.getBoolAttr(true));
1679
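// With use_log this replaces the op with sub(shifted_logits, log(sum));
// otherwise with div(exp, sum), matching the formulas in the class comment.
// For example, for logits [1, 2, 3] the max is 3, the shifted logits are
// [-2, -1, 0], and the result is [e^-2, e^-1, 1] / (e^-2 + e^-1 + 1),
// identical to the unshifted formula.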
1680 if (use_log) {
1681 Value log = rewriter.create<TF::LogOp>(loc, sum);
1682 rewriter.replaceOpWithNewOp<TF::SubOp>(op, shifted_logits, log);
1683 } else {
1684 rewriter.replaceOpWithNewOp<TF::DivOp>(op, exp, sum);
1685 }
1686 return success();
1687 }
1688 };
1689
1690 } // namespace
1691
1692 void PopulateLoweringTFPatterns(MLIRContext *context,
1693 RewritePatternSet *patterns) {
1694 // clang-format off
1695 patterns->add<
1696 LowerAddNOp,
1697 LowerExp1mOp,
1698 ConvertFakeQuantWithMinMaxVarsOp,
1699 LowerDynamicStitchOp<DynamicStitchOp>,
1700 LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1701 LowerInvertPermutationOp,
1702 LowerLgammaOp,
1703 LowerPackOp,
1704 LowerBatchToSpaceND,
1705 LowerSpaceToBatchNDOp,
1706 LowerResizeNearestNeighbor,
1707 LowerSparseMatMulOp,
1708 Lower_UnaryOpsComposition,
1709 LowerRollOp>(context);
1710 // clang-format on
1711 populateWithGenerated(*patterns);
1712 }
1713
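// Note: compared to PopulateLoweringTFPatterns above, this set adds the
// Softmax/LogSoftmax decompositions and omits a few patterns (e.g.
// LowerLgammaOp and LowerExp1mOp); presumably those ops are handled directly
// by the subsequent TF-to-HLO legalization.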
1714 void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
1715 RewritePatternSet *patterns) {
1716 // clang-format off
1717 patterns->add<
1718 ConvertFakeQuantWithMinMaxVarsOp,
1719 LowerAddNOp,
1720 LowerBatchToSpaceND,
1721 LowerDynamicStitchOp<DynamicStitchOp>,
1722 LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1723 LowerInvertPermutationOp,
1724 LowerPackOp,
1725 LowerResizeNearestNeighbor,
1726 LowerSoftmaxOp<TF::LogSoftmaxOp, /*use_log=*/true>,
1727 LowerSoftmaxOp<TF::SoftmaxOp, /*use_log=*/false>,
1728 LowerSpaceToBatchNDOp,
1729 LowerSparseMatMulOp,
1730 Lower_UnaryOpsComposition,
1731 LowerRollOp>(context);
1732 // clang-format on
1733
1734 // Populate the relevant generated patterns.
1735 // clang-format off
1736 patterns->add<
1737 LowerAddOp,
1738 LowerBiasAddGradOp,
1739 LowerDivNoNanOp,
1740 LowerEmptyOp,
1741 LowerFakeQuantWithMinMaxArgs,
1742 LowerFillOp,
1743 LowerInv,
1744 LowerIsNanOp,
1745 LowerL2LossOp,
1746 LowerMulNoNanOp,
1747 LowerPadOp,
1748 LowerReciprocal,
1749 LowerRintOp,
1750 LowerRoundOpOnFloatTensor,
1751 LowerRoundOpOnIntTensor,
1752 LowerRsqrtGradOp,
1753 LowerScatterNdOp,
1754 LowerSeluOp,
1755 LowerSeluGradOp,
1756 LowerSizeOp,
1757 LowerSoftmaxCrossEntropyWithLogitsOp,
1758 LowerSparseSoftmaxCrossEntropyWithLogitsOp,
1759 LowerSqrtGradOp,
1760 LowerSquareOp,
1761 LowerSquaredDifferenceOpOnRealTensors,
1762 LowerSquaredDifferenceOpOneComplexTensors,
1763 LowerTanhGradOp,
1764 LowerTruncateDivOp,
1765 LowerXdivyOp,
1766 LowerXlog1pyOp,
1767 LowerXlogyOp>(context);
1768 // clang-format on
1769 }
1770
1771 void PopulateLoweringQuantizedPatterns(MLIRContext *context,
1772 RewritePatternSet *patterns) {
1773 // clang-format off
1774 patterns->add<
1775 LowerDequantizeOp>(context);
1776 // clang-format on
1777 }
1778
1779 } // namespace TF
1780 } // namespace mlir
1781