xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
17 
18 #include <numeric>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
25 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
26 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
27 #include "mlir/IR/Matchers.h"  // from @llvm-project
28 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
29 #include "mlir/IR/TypeRange.h"  // from @llvm-project
30 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
37 #include "tensorflow/core/util/tensor_format.h"
38 
39 namespace mlir {
40 namespace TF {
41 namespace {
42 
43 // Returns a 1D 64-bit dense elements attribute with the given values.
44 static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
45                                                Builder *builder) {
46   RankedTensorType ty = RankedTensorType::get(
47       {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
48   return DenseIntElementsAttr::get(ty, values);
49 }
50 
51 // Returns a 1D i64 elements attribute populated with the numbers from start
52 // (inclusive) to end (exclusive).
53 static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end,
54                                                      Builder *builder) {
55   int size = end - start;
56 
57   SmallVector<int64_t, 4> vals;
58   vals.resize(size);
59   std::iota(vals.begin(), vals.end(), start);
60 
61   TensorType ty = RankedTensorType::get({size}, builder->getIntegerType(64));
62   return DenseIntElementsAttr::get(ty, vals);
63 }
64 
65 // Returns a scalar f32 elements attribute representing the given value.
66 static DenseElementsAttr GetF32Scalar(OpBuilder *builder, float value) {
67   return DenseElementsAttr::get(
68       RankedTensorType::get({}, builder->getF32Type()),
69       FloatAttr::get(builder->getF32Type(), value));
70 }
71 
72 // Returns a TF_CastOp to F32. This function is used for CastOps that are
73 // intermediate nodes in a TableGen pattern result. In such a case, the
74 // destination type is not inferred and must be given explicitly.
75 //
76 // Preconditions: The given value must have a ShapedType.
77 static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x,
78                                BoolAttr truncate) {
79   auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
80   if (!x_type) llvm_unreachable("unsupported type");
81   Type type = x_type.clone(builder->getF32Type());
82   return builder->create<CastOp>(loc, type, x, truncate);
83 }
84 
85 // Returns a TF_CastOp to I32. This function is used for CastOps that are
86 // intermediate nodes in a TableGen pattern result. In such a case, the
87 // destination type is not inferred and must be given explicitly.
88 //
89 // Preconditions: The given value must have a ShapedType.
90 static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x,
91                                BoolAttr truncate) {
92   auto x_type = x.getType().dyn_cast_or_null<ShapedType>();
93   if (!x_type) llvm_unreachable("unsupported type");
94   Type type = x_type.clone(builder->getI32Type());
95   return builder->create<CastOp>(loc, type, x, truncate);
96 }
97 
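// Converts the given double to an APFloat whose precision matches the given
// floating point type: single precision for a 32-bit type, double otherwise.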
98 static APFloat ConvertToAPFloat(double val, Type type) {
99   if (type.getIntOrFloatBitWidth() == 32) {
100     return APFloat(static_cast<float>(val));
101   }
102 
103   return APFloat(val);
104 }
105 
106 // Return true if the passed quantized type is unsigned.
107 bool QuantizedTypeIsUnsigned(Type type) {
108   return TypeSwitch<Type, bool>(type)
109       .Case<mlir::TF::Qint8Type>([](Type) { return false; })
110       .Case<mlir::TF::Qint16Type>([](Type) { return false; })
111       .Case<mlir::TF::Qint32Type>([](Type) { return false; })
112       .Case<mlir::TF::Quint8Type>([](Type) { return true; })
113       .Case<mlir::TF::Quint16Type>([](Type) { return true; })
114       .Default([](Type) {
115         llvm_unreachable("QuantizedTypeIsUnsigned: not a quantized type");
116         return false;
117       });
118 }
119 
120 // Returns the half_range value that is used by DequantizeOp. half_range is
121 // used to offset the quantized representation before it gets scaled. For
122 // signed quantized types, this offset is half the type's range.
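// For example, a quint8 input yields a half_range of 0, while a qint8 input
// yields 128.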
123 static DenseElementsAttr DequantizeHalfRange(OpBuilder *builder, Value input) {
124   auto input_type = input.getType().dyn_cast_or_null<ShapedType>();
125   if (!input_type) llvm_unreachable("DequantizeHalfRange: not a ShapedType");
126   bool is_unsigned = QuantizedTypeIsUnsigned(input_type.getElementType());
127   float half_range = is_unsigned ? 0 : 128;
128   return GetScalarOfType(builder->getF32Type(), half_range);
129 }
130 
131 // Returns reduction indices to use while lowering tf.BiasAddGrad op to tf.Sum
132 // op.
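// For example, for rank 4 this returns {0, 1, 2} in NHWC format (feature
// dimension 3) and {0, 2, 3} in NCHW format (feature dimension 1).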
133 DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank,
134                                                     StringAttr data_format,
135                                                     Builder *builder) {
136   tensorflow::TensorFormat format;
137   if (!FormatFromString(data_format.getValue().str(), &format)) return {};
138 
139   // Reduce along all dimensions except the feature dimension.
140   int64_t feature_dim = GetTensorFeatureDimIndex(rank, format);
141   llvm::SmallVector<int64_t, 4> dims_to_reduce(rank - 1);
142   std::iota(dims_to_reduce.begin(), dims_to_reduce.begin() + feature_dim, 0);
143   std::iota(dims_to_reduce.begin() + feature_dim, dims_to_reduce.end(),
144             feature_dim + 1);
145   return GetI64ElementsAttr(dims_to_reduce, builder);
146 }
147 
148 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_lower_tf.inc"
149 
150 // Infers ExpandDims op output type for the given input type `ty` and dimension
151 // to expand at the given `axis`.
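// For example, expanding tensor<2x3xf32> at axis 1 (equivalently, axis -2)
// yields tensor<2x1x3xf32>.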
152 Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) {
153   auto ranked_ty = ty.dyn_cast<RankedTensorType>();
154 
155   // Unranked type.
156   if (!ranked_ty) return ty;
157 
158   auto shape = llvm::to_vector<4>(ranked_ty.getShape());
159   if (axis < 0) axis += ranked_ty.getRank() + 1;
160 
161   shape.insert(shape.begin() + axis, 1);
162   return RankedTensorType::get(shape, ranked_ty.getElementType());
163 }
164 
165 // Converts individual Values to a tensor of rank 1. Each input Value has rank 1
166 // and size 1.
167 Value ValuesToRank1(PatternRewriter &rewriter, Location loc, Type dtype,
168                     ArrayRef<Value> vals) {
169   int64_t length = vals.size();
170   auto type = RankedTensorType::get({length}, dtype);
171   auto axis = rewriter.create<ConstOp>(
172       loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
173   return rewriter.create<ConcatV2Op>(loc, type, ValueRange(vals), axis);
174 }
175 
176 // Lowers AddN op to a sequence of AddV2 ops to accumulate operands.
177 //
178 // Note that to improve the parallelism, AddN op uses tree-based reduction.
179 // For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows:
180 //
181 //                 0     1     2     3     4
182 //                 |     |     |     |     |
183 //                 -------     -------     |
184 //                    |           |        |
185 //                    5           6        |
186 //                    |           |        |
187 //                    -------------        |
188 //                          |              |
189 //                          7              |
190 //                          |              |
191 //                          ----------------
192 //                                 |
193 //                                 8
194 //
195 // Example:
196 //
197 //   %result = "tf.AddN"(%0, %1, %2)
198 //
199 // is lowered to:
200 //
201 //   %sum0 = "tf.AddV2"(%0, %1)
202 //   %result = "tf.AddV2"(%sum0, %2)
203 //
204 // While
205 //
206 //   %result = "tf.AddN"(%0, %1, %2, %3, %4)
207 //
208 // is lowered to:
209 //
210 //   %sum0 = "tf.AddV2"(%0, %1)
211 //   %sum1 = "tf.AddV2"(%2, %3)
212 //   %sum2 = "tf.AddV2"(%sum0, %sum1)
213 //   %result = "tf.AddV2"(%sum2, %4)
214 //
215 class LowerAddNOp : public RewritePattern {
216  public:
217   explicit LowerAddNOp(MLIRContext *context)
218       : RewritePattern(AddNOp::getOperationName(), 1, context,
219                        {AddV2Op::getOperationName()}) {}
220 
221   LogicalResult matchAndRewrite(Operation *op,
222                                 PatternRewriter &rewriter) const override {
223     auto addn_op = cast<AddNOp>(op);
224 
225     // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't
226     // support variant type so variant types require special handling.
227     if (getElementTypeOrSelf(addn_op.getType()).isa<VariantType>())
228       return failure();
229     llvm::SmallVector<Value, 4> operands(addn_op.inputs().begin(),
230                                          addn_op.inputs().end());
231 
232     int64_t n = operands.size();
233     // Keep doing tree-based reduction while there is more than one operand.
234     while (n > 1) {
235       for (int64_t i = 0; i < n; i += 2) {
236         // Add two adjacent operands if applicable.
237         operands[i / 2] =
238             (i + 1 < n) ? rewriter.create<AddV2Op>(addn_op.getLoc(),
239                                                    operands[i], operands[i + 1])
240                         : operands[i];
241       }
242       n = (n + 1) / 2;
243     }
244 
245     rewriter.replaceOp(addn_op, operands[0]);
246     return success();
247   }
248 };
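
// Illustration only (not part of the lowering): a minimal sketch of the same
// pairwise tree-reduction schedule used by LowerAddNOp::matchAndRewrite,
// applied to plain integers instead of SSA values. The helper name
// TreeReduceSum is hypothetical and exists purely to make the pairing logic
// concrete. Assumes at least one element.
static int64_t TreeReduceSum(llvm::SmallVector<int64_t, 8> operands) {
  int64_t n = operands.size();
  while (n > 1) {
    for (int64_t i = 0; i < n; i += 2) {
      // Sum adjacent pairs; an unpaired trailing element is carried over to
      // the next round unchanged.
      operands[i / 2] =
          (i + 1 < n) ? operands[i] + operands[i + 1] : operands[i];
    }
    n = (n + 1) / 2;  // Number of partial results produced in this round.
  }
  return operands[0];
}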
249 
250 // Lowers DynamicStitch op with constant indices and with static input and
251 // output shapes using Reshape, Unpack and Pack ops.
252 //
253 //   %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
254 //   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> :
255 //     tensor<2x2xi32>}
256 //   %0 = "tf.DynamicStitch"(%indices0, %indices1, %arg0, %arg1)
257 //     : (tensor<i32>, tensor<2x2xi32>, tensor<2xf32>, tensor<2x2x2xf32>)
258 //     -> tensor<5x2xf32>
259 //
260 // is lowered to
261 //
262 //   %shape = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>}
263 //   %inp0 = "tf.Reshape"(%arg0, %shape)
264 //     : (tensor<2xf32>, tensor<2xi64>) -> tensor<1x2xf32>
265 //   %inp1 = "tf.Reshape"(%arg1, %shape)
266 //     : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
267 //   %items0 = "tf.Unpack"(%inp0) {axis = 0 : i64}
268 //     : (tensor<1x2xf32>) -> tensor<2xf32>
269 //   %items1:4 = "tf.Unpack"(%inp1) {axis = 0 : i64}
270 //     : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
271 //     tensor<2xf32>)
272 //   %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
273 //   %0 = "tf.Pack"(%items1#3, %items1#2, %items1#1, %items1#0, %items0, %axis)
274 //     : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
275 //        tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
276 //
277 template <typename OpT>
278 class LowerDynamicStitchOp : public RewritePattern {
279  public:
280   explicit LowerDynamicStitchOp(MLIRContext *context)
281       : RewritePattern(
282             OpT::getOperationName(), 1, context,
283             {ConstOp::getOperationName(), ReshapeOp::getOperationName(),
284              UnpackOp::getOperationName(), PackOp::getOperationName()}) {}
285 
286   LogicalResult matchAndRewrite(Operation *src_op,
287                                 PatternRewriter &rewriter) const override {
288     auto op = cast<OpT>(src_op);
289 
290     // Static output type is used to compute intermediate values. Note that the
291     // output type doesn't have to be static, but if the input shapes are static
292     // and the indices are constant, the output type can be statically determined.
293     RankedTensorType out_ty =
294         op.getType().template dyn_cast<RankedTensorType>();
295     if (!out_ty || !out_ty.hasStaticShape()) return failure();
296 
297     // Extract all the constant indices' attributes and verify that the data
298     // operands have static shapes.
299     SmallVector<DenseIntElementsAttr, 4> indices;
300     indices.reserve(op.N());
301     for (auto it : llvm::zip(op.indices(), op.data())) {
302       Value index = std::get<0>(it);
303       Value data = std::get<1>(it);
304 
305       DenseIntElementsAttr index_attr;
306       if (!matchPattern(index, m_Constant(&index_attr))) return failure();
307       indices.push_back(index_attr);
308 
309       RankedTensorType data_ty =
310           data.getType().template dyn_cast<RankedTensorType>();
311       if (!data_ty || !data_ty.hasStaticShape()) return failure();
312     }
313 
314     // Compute the type of each item and the shape to use when reshaping the
315     // inputs so that they can be unpacked into individual items.
316     ArrayRef<int64_t> item_shape = out_ty.getShape().drop_front(1);
317     auto item_ty = RankedTensorType::get(item_shape, out_ty.getElementType());
318 
319     SmallVector<int64_t, 4> packed_shape;
320     packed_shape.push_back(-1);
321     packed_shape.append(item_shape.begin(), item_shape.end());
322     Location loc = op.getLoc();
323     auto packed_shape_val = rewriter.create<ConstOp>(
324         loc, GetI64ElementsAttr(packed_shape, &rewriter));
325 
326     // Prepare each output item by unpacking the data and then placing it at
327     // the specified index.
328     SmallVector<Value, 8> values(out_ty.getDimSize(0));
329     for (auto it : llvm::zip(indices, op.data())) {
330       DenseIntElementsAttr index_attr = std::get<0>(it);
331       Value data = std::get<1>(it);
332 
333       auto reshaped_data =
334           rewriter.create<ReshapeOp>(loc, data, packed_shape_val);
335       auto num_items = reshaped_data.getType()
336                            .template cast<RankedTensorType>()
337                            .getShape()[0];
338       auto items = rewriter.create<UnpackOp>(
339           loc, SmallVector<Type, 4>(num_items, item_ty), reshaped_data,
340           /*axis=*/0);
341       for (auto index_item : llvm::zip(index_attr, items.getResults())) {
342         int64_t output_index = std::get<0>(index_item).getSExtValue();
343         Value item = std::get<1>(index_item);
344         values[output_index] = item;
345       }
346     }
347 
348     rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
349     return success();
350   }
351 };
352 
353 // This pattern lowers FakeQuantWithMinMaxVars by manually converting between
354 // floating point and quantized space. It is designed to reproduce TF's
355 // implementation, mirroring the previous XLA implementation.
356 //
357 // 1. Compute proper quantized bounds. This involves nudging the input bounds.
358 // 2. Convert the input to quantized space, rounding values.
359 // 3. Convert back into floating point space.
360 class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern {
361  public:
362   explicit ConvertFakeQuantWithMinMaxVarsOp(MLIRContext *context)
363       : RewritePattern(
364             FakeQuantWithMinMaxVarsOp::getOperationName(), 1, context,
365             {AddV2Op::getOperationName(), SubOp::getOperationName(),
366              ConstOp::getOperationName(), MulOp::getOperationName(),
367              FloorOp::getOperationName(), ClipByValueOp::getOperationName(),
368              DivOp::getOperationName(), RoundOp::getOperationName()}) {}
369 
370   LogicalResult matchAndRewrite(Operation *src_op,
371                                 PatternRewriter &rewriter) const override {
372     auto op = cast<FakeQuantWithMinMaxVarsOp>(src_op);
373 
374     auto input = op.inputs();
375     auto input_ty = input.getType().cast<ShapedType>();
376     auto element_ty = input_ty.getElementType();
377     auto scalar_ty = RankedTensorType::get({}, element_ty);
378 
379     auto num_bits = op.num_bits();
380     auto narrow_range = op.narrow_range();
381     const double bits_min = narrow_range ? 1 : 0;
382     const double bits_max = (1 << num_bits) - 1;
383 
384     auto float_min = op.min();
385     auto float_max = op.max();
386 
387     auto float_diff = rewriter.create<SubOp>(op.getLoc(), float_max, float_min);
388 
389     // Compute the range when quantized.
390     auto quant_min = rewriter.create<ConstOp>(
391         op.getLoc(), DenseElementsAttr::get(
392                          scalar_ty, ConvertToAPFloat(bits_min, element_ty)));
393 
394     auto quant_max = rewriter.create<ConstOp>(
395         op.getLoc(), DenseElementsAttr::get(
396                          scalar_ty, ConvertToAPFloat(bits_max, element_ty)));
397 
398     auto quant_diff = rewriter.create<ConstOp>(
399         op.getLoc(),
400         DenseElementsAttr::get(
401             scalar_ty, ConvertToAPFloat(bits_max - bits_min, element_ty)));
402 
403     auto quant_to_float =
404         rewriter.create<DivOp>(op.getLoc(), float_diff, quant_diff);
405 
406     auto float_to_quant =
407         rewriter.create<DivOp>(op.getLoc(), quant_diff, float_diff);
408 
409     // During quantization, the quantized min/max values may not line up
410     // perfectly with the specified min/max. Nudge them into the right range.
411     auto min_scaled =
412         rewriter.create<DivOp>(op.getLoc(), float_min, quant_to_float);
413     auto min_scaled_sub =
414         rewriter.create<SubOp>(op.getLoc(), quant_min, min_scaled);
415 
416     auto mid_rounded =
417         rewriter.create<RoundOp>(op.getLoc(), scalar_ty, min_scaled_sub);
418 
419     auto nudged_zero_point_val = rewriter.create<ClipByValueOp>(
420         op.getLoc(), scalar_ty, mid_rounded, quant_min, quant_max);
421 
422     auto quant_min_sub =
423         rewriter.create<SubOp>(op.getLoc(), quant_min, nudged_zero_point_val);
424     auto quant_max_sub =
425         rewriter.create<SubOp>(op.getLoc(), quant_max, nudged_zero_point_val);
426 
427     auto nudged_float_min =
428         rewriter.create<MulOp>(op.getLoc(), quant_min_sub, quant_to_float);
429 
430     auto nudged_float_max =
431         rewriter.create<MulOp>(op.getLoc(), quant_max_sub, quant_to_float);
432 
433     // Now quantize the input value with the approximated min/max values.
434 
435     // Move the input value into quantized space
436     Value quantized_input = rewriter.create<ClipByValueOp>(
437         op.getLoc(), input_ty, input, nudged_float_min, nudged_float_max);
438 
439     quantized_input = rewriter.create<SubOp>(op.getLoc(), input_ty,
440                                              quantized_input, nudged_float_min);
441 
442     quantized_input = rewriter.create<MulOp>(op.getLoc(), input_ty,
443                                              quantized_input, float_to_quant);
444 
445     // Round the quantized input, breaking ties toward the positive direction.
446     auto half_val = rewriter.create<ConstOp>(
447         op.getLoc(),
448         DenseElementsAttr::get(scalar_ty, ConvertToAPFloat(0.5, element_ty)));
449 
450     quantized_input = rewriter.create<AddV2Op>(op.getLoc(), input_ty,
451                                                quantized_input, half_val);
452 
453     quantized_input = rewriter.create<FloorOp>(op.getLoc(), quantized_input);
454 
455     // Convert back into floating point space.
456     Value output = rewriter.create<MulOp>(op.getLoc(), input_ty,
457                                           quantized_input, quant_to_float);
458 
459     output = rewriter.create<AddV2Op>(op.getLoc(), input_ty, output,
460                                       nudged_float_min);
461 
462     rewriter.replaceOp(op, {output});
463     return success();
464   }
465 };
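
// Illustration only (not part of the lowering): the scalar arithmetic that the
// op sequence above mirrors, written with plain floats. FakeQuantScalar is a
// hypothetical helper name; assumes <cmath> and <algorithm> are available.
static float FakeQuantScalar(float value, float float_min, float float_max,
                             float bits_min, float bits_max) {
  float quant_to_float = (float_max - float_min) / (bits_max - bits_min);
  float float_to_quant = (bits_max - bits_min) / (float_max - float_min);
  // Nudge the zero point so that the quantized grid lines up with the
  // specified float range (tf.Round in the lowering; rounding-mode differences
  // are ignored here).
  float zero_point = std::round(bits_min - float_min / quant_to_float);
  zero_point = std::min(std::max(zero_point, bits_min), bits_max);
  float nudged_min = (bits_min - zero_point) * quant_to_float;
  float nudged_max = (bits_max - zero_point) * quant_to_float;
  // Clip, quantize with round-half-up, and convert back to float space.
  float clipped = std::min(std::max(value, nudged_min), nudged_max);
  float quantized = std::floor((clipped - nudged_min) * float_to_quant + 0.5f);
  return quantized * quant_to_float + nudged_min;
}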
466 
467 // Lowers InvertPermutation op to TensorScatterUpdate op.
468 //
469 // Example:
470 //
471 //   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
472 //   "tf.InvertPermutation"(%x) : (tensor<5xi32>) -> tensor<5xi32>
473 //
474 // is lowered to
475 //
476 //   %x = "tf.Const"() {value = dense<[3, 4, 0, 1, 2]> : tensor<5xi32>}
477 //   %start = "tf.Const"() {value = dense<0> : tensor<i32>}
478 //   %limit = "tf.Const"() {value = dense<5> : tensor<i32>}
479 //   %delta = "tf.Const"() {value = dense<1> : tensor<i32>}
480 //   %updates = "tf.Range"(%start, %limit, %delta) :
481 //     (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5xi32>
482 //   %shape = "tf.Const"() {value = dense<[5, 1]> : tensor<2xi32>}
483 //   %indices = "tf.Reshape"(%x, %shape) : (tensor<5xi32>, tensor<2xi32>) ->
484 //     tensor<5x1xi32>
485 //   "tf.TensorScatterUpdate"(%x, %indices, %updates) :
486 //     (tensor<5xi32>, tensor<5x1xi32>, tensor<5xi32>) -> tensor<5xi32>
487 //
488 class LowerInvertPermutationOp : public RewritePattern {
489  public:
490   explicit LowerInvertPermutationOp(MLIRContext *context)
491       : RewritePattern(
492             InvertPermutationOp::getOperationName(), 1, context,
493             {ConstOp::getOperationName(), RangeOp::getOperationName(),
494              ReshapeOp::getOperationName(),
495              TensorScatterUpdateOp::getOperationName()}) {}
496 
497   LogicalResult matchAndRewrite(Operation *src_op,
498                                 PatternRewriter &rewriter) const override {
499     auto op = cast<InvertPermutationOp>(src_op);
500 
501     Location loc = op.getLoc();
502     auto x_type = op.x().getType().dyn_cast<RankedTensorType>();
503     // x input must have static shape.
504     if (!x_type || !x_type.hasStaticShape()) {
505       return failure();
506     }
507     Type int_type = x_type.getElementType();  // Could be i32 or i64.
508 
509     auto result_type = x_type;
510     auto start = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 0));
511     Value limit = rewriter.create<ConstOp>(
512         loc, GetScalarOfType(int_type, x_type.getShape()[0]));
513     auto delta = rewriter.create<ConstOp>(loc, GetScalarOfType(int_type, 1));
514     // Construct a sequence of numbers [0, 1, ... len(x)-1].
515     auto updates =
516         rewriter.create<RangeOp>(loc, result_type, start, limit, delta);
517 
518     auto shape_type = RankedTensorType::get({2}, rewriter.getIntegerType(32));
519     auto shape = rewriter.create<ConstOp>(
520         loc, DenseElementsAttr::get(
521                  shape_type, {static_cast<int>(x_type.getDimSize(0)), 1}));
522     auto indices = rewriter.create<ReshapeOp>(loc, op.x(), shape);
523 
524     rewriter.replaceOpWithNewOp<TensorScatterUpdateOp>(op, result_type, op.x(),
525                                                        indices, updates);
526     return success();
527   }
528 };
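
// Illustration only: the scatter performed by the lowering above, expressed on
// plain index vectors (result[x[i]] = i). InvertPermutationRef is a
// hypothetical name used for exposition; x must be a valid permutation of
// [0, x.size()).
static llvm::SmallVector<int64_t, 8> InvertPermutationRef(
    llvm::ArrayRef<int64_t> x) {
  llvm::SmallVector<int64_t, 8> result(x.size());
  for (int64_t i = 0, e = x.size(); i < e; ++i) result[x[i]] = i;
  return result;
}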
529 
530 // Approximates lgamma using Lanczos' approximation from
531 // "A Precision Approximation of the Gamma Function". SIAM Journal on Numerical
532 // Analysis series B. Vol. 1:
533 // lgamma(z + 1) = (log(2) + log(pi)) / 2 + (z + 1/2) * log(t(z)) - t(z) + log(A(z))
534 // t(z) = z + kLanczosGamma + 1/2
535 // A(z) = kBaseLanczosCoeff
536 //       + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
537 //
538 // Coefficients for the Lanczos approximation of the gamma function. The
539 // coefficients are uniquely determined by the choice of g and n
540 // (kLanczosGamma and kLanczosCoefficients.size() + 1). The coefficients below
541 // correspond to [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were
542 // evaluated and [7, 9] seemed to be the least sensitive to the quality of the
543 // log function. In particular, [5, 7] is the only choice where -1.5e-5 <=
544 // lgamma(2) <= 1.5e-5 for a particularly inaccurate log function.
545 static constexpr double kLanczosGamma = 7;  // aka g
546 static constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
547 static constexpr std::array<double, 8> kLanczosCoefficients = {
548     676.520368121885098567009190444019, -1259.13921672240287047156078755283,
549     771.3234287776530788486528258894,   -176.61502916214059906584551354,
550     12.507343278686904814458936853,     -0.13857109526572011689554707,
551     9.984369578019570859563e-6,         1.50563273514931155834e-7};
552 
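// Illustration only: the scalar form of the Lanczos evaluation that the
// pattern below builds out of TF ops, ignoring the reflection for x < 0.5, the
// infinity handling, and the overflow-avoiding rearrangement of the first
// term. LanczosLgammaRef is a hypothetical name; assumes <cmath>.
static double LanczosLgammaRef(double x) {
  double z = x - 1.0;  // The series approximates lgamma(z + 1).
  double a = kBaseLanczosCoeff;
  for (int k = 1; k <= static_cast<int>(kLanczosCoefficients.size()); ++k)
    a += kLanczosCoefficients[k - 1] / (z + k);
  double t = z + kLanczosGamma + 0.5;
  return (std::log(2.0) + std::log(M_PI)) / 2 + (z + 0.5) * std::log(t) - t +
         std::log(a);
}
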
553 class LowerLgammaOp : public RewritePattern {
554  public:
555   explicit LowerLgammaOp(MLIRContext *context)
556       : RewritePattern(LgammaOp::getOperationName(), 1, context,
557                        {
558                            CastOp::getOperationName(),
559                            ConstOp::getOperationName(),
560                            NegOp::getOperationName(),
561                            SubOp::getOperationName(),
562                            SelectV2Op::getOperationName(),
563                            LessOp::getOperationName(),
564                            AddV2Op::getOperationName(),
565                            DivOp::getOperationName(),
566                            SubOp::getOperationName(),
567                            LogOp::getOperationName(),
568                            Log1pOp::getOperationName(),
569                            IsInfOp::getOperationName(),
570                            MulOp::getOperationName(),
571                            FloorOp::getOperationName(),
572                            AbsOp::getOperationName(),
573                            GreaterOp::getOperationName(),
574                            SinOp::getOperationName(),
575                            IsFiniteOp::getOperationName(),
576                        }) {}
577 
578   LogicalResult matchAndRewrite(Operation *src_op,
579                                 PatternRewriter &rewriter) const override {
580     auto op = cast<LgammaOp>(src_op);
581 
582     Location loc = op.getLoc();
583     Value input = op.x();
584     TensorType original_tensor_type = op.x().getType().cast<TensorType>();
585 
586     // The approximation is not precise enough for float16. Do the computation
587     // in float32 for that case.
588     TensorType tensor_type = original_tensor_type;
589     FloatType float_type = tensor_type.getElementType().cast<FloatType>();
590     bool needs_cast = float_type.getWidth() < 32;
591     if (needs_cast) {
592       MLIRContext *context = rewriter.getContext();
593       float_type = FloatType::getF32(context);
594       if (original_tensor_type.hasRank()) {
595         tensor_type =
596             RankedTensorType::get(original_tensor_type.getShape(), float_type);
597       } else {
598         tensor_type = UnrankedTensorType::get(float_type);
599       }
600       input = rewriter.create<CastOp>(loc, tensor_type, input);
601     }
602 
603     // Helper lambda function for creating a ConstOp for a tensor filled with
604     // the given constant float value.
605     auto create_const_op = [&rewriter, loc, tensor_type,
606                             float_type](double value) {
607       return rewriter.create<ConstOp>(
608           loc, DenseElementsAttr::get(tensor_type,
609                                       FloatAttr::get(float_type, value)));
610     };
611 
612     Value one_half = create_const_op(0.5);
613     Value one = create_const_op(1.0);
614     Value infinity = create_const_op(std::numeric_limits<double>::infinity());
615     Value pi = create_const_op(M_PI);
616     Value log_pi = create_const_op(std::log(M_PI));
617     Value log_sqrt_two_pi = create_const_op((std::log(2) + std::log(M_PI)) / 2);
618     Value lanczos_gamma_plus_one_half = create_const_op(kLanczosGamma + 0.5);
619     Value log_lanczos_gamma_plus_one_half =
620         create_const_op(std::log(kLanczosGamma + 0.5));
621     Value base_lanczos_coeff = create_const_op(kBaseLanczosCoeff);
622 
623     Value minus_input = rewriter.create<NegOp>(loc, input);
624     Value input_minus_one = rewriter.create<SubOp>(loc, input, one);
625 
626     // If the input is less than 0.5 use Euler's reflection formula:
627     // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
628     Value need_to_reflect = rewriter.create<LessOp>(loc, input, one_half);
629     Type tensor_bool_type = need_to_reflect.getType();
630     Value z = rewriter.create<SelectV2Op>(loc, need_to_reflect, minus_input,
631                                           input_minus_one);
632 
633     Value x = base_lanczos_coeff;
634     for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
635       Value lanczos_coefficient = create_const_op(kLanczosCoefficients[i]);
636       Value index = create_const_op(static_cast<double>(i));
637       Value z_plus_index = rewriter.create<AddV2Op>(loc, z, index);
638       Value z_plus_index_plus_one =
639           rewriter.create<AddV2Op>(loc, z_plus_index, one);
640       Value incr = rewriter.create<DivOp>(loc, lanczos_coefficient,
641                                           z_plus_index_plus_one);
642       x = rewriter.create<AddV2Op>(loc, x, incr);
643     }
644 
645     // To improve accuracy on platforms with less-precise log implementations,
646     // compute log(lanczos_gamma_plus_one_half) at compile time and use log1p on
647     // the device.
648     // log(t) = log(kLanczosGamma + 0.5 + z)
649     //        = log(kLanczosGamma + 0.5) + log1p(z / (kLanczosGamma + 0.5))
650     Value t = rewriter.create<AddV2Op>(loc, lanczos_gamma_plus_one_half, z);
651     Value z_div_lanczos_gamma_plus_one_half =
652         rewriter.create<DivOp>(loc, z, lanczos_gamma_plus_one_half);
653     Value log1p_z_div_lanczos_gamma_plus_one_half =
654         rewriter.create<Log1pOp>(loc, z_div_lanczos_gamma_plus_one_half);
655     Value log_t =
656         rewriter.create<AddV2Op>(loc, log_lanczos_gamma_plus_one_half,
657                                  log1p_z_div_lanczos_gamma_plus_one_half);
658 
659     // Compute the final result (modulo reflection).  t(z) may be large, and we
660     // need to be careful not to overflow to infinity in the first term of
661     //
662     //   (z + 1/2) * log(t(z)) - t(z).
663     //
664     // Therefore we compute this as
665     //
666     //   (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
667     //
668     // log_y = log_sqrt_two_pi + (z + one_half - t / log_t) * log_t + Log(x);
669     Value t_div_log_t = rewriter.create<DivOp>(loc, t, log_t);
670     Value one_half_minus_t_div_log_t =
671         rewriter.create<SubOp>(loc, one_half, t_div_log_t);
672     Value z_plus_one_half_minus_t_div_log_t =
673         rewriter.create<AddV2Op>(loc, z, one_half_minus_t_div_log_t);
674     Value z_plus_one_half_minus_t_div_log_t_mul_log_t =
675         rewriter.create<MulOp>(loc, z_plus_one_half_minus_t_div_log_t, log_t);
676     Value log_x = rewriter.create<LogOp>(loc, x);
677     Value log_y_rhs = rewriter.create<AddV2Op>(
678         loc, z_plus_one_half_minus_t_div_log_t_mul_log_t, log_x);
679     Value log_y = rewriter.create<AddV2Op>(loc, log_sqrt_two_pi, log_y_rhs);
680 
681     // Compute the reflected value, used when x < 0.5:
682     //
683     //   lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
684     //
685     // (The abs is because lgamma is the log of the absolute value of the gamma
686     // function.)
687     //
688     // We have to be careful when computing the final term above. gamma(x) goes
689     // to +/-inf at every integer x < 0, and this is controlled by the
690     // sin(pi * x) term.  The slope is large, so precision is particularly
691     // important.
692     //
693     // Because abs(sin(pi * x)) has period 1, we can equivalently use
694     // abs(sin(pi * frac(x))), where frac(x) is the fractional part of x.  This
695     // is more numerically accurate: It doesn't overflow to inf like pi * x can,
696     // and if x is an integer, it evaluates to 0 exactly, which is significant
697     // because we then take the log of this value, and log(0) is inf.
698     //
699     // We don't have a frac(x) primitive in XLA and computing it is tricky, but
700     // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for
701     // our purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
702     //
703     // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
704     // to 1.  To remedy this, we can use the fact that sin(pi * x) in the domain
705     // [0, 1] is symmetric across the line Y=0.5.
706     Value abs_input = rewriter.create<AbsOp>(loc, input);
707     Value abs_input_floor = rewriter.create<FloorOp>(loc, abs_input);
708     Value abs_frac_input =
709         rewriter.create<SubOp>(loc, abs_input, abs_input_floor);
710 
711     // Convert values of abs_frac_input > 0.5 to (1 - frac_input) to improve
712     // precision of pi * abs_frac_input for values of abs_frac_input close to 1.
713     Value one_minus_abs_frac_input =
714         rewriter.create<SubOp>(loc, one, abs_frac_input);
715     Value abs_frac_input_gt_one_half =
716         rewriter.create<GreaterOp>(loc, abs_frac_input, one_half);
717     Value reduced_frac_input =
718         rewriter.create<SelectV2Op>(loc, abs_frac_input_gt_one_half,
719                                     one_minus_abs_frac_input, abs_frac_input);
720     Value pi_mul_reduced_frac_input =
721         rewriter.create<MulOp>(loc, pi, reduced_frac_input);
722     Value sin_pi_mul_reduced_frac_input =
723         rewriter.create<SinOp>(loc, pi_mul_reduced_frac_input);
724     Value reflection_denom =
725         rewriter.create<LogOp>(loc, sin_pi_mul_reduced_frac_input);
726 
727     // Avoid computing -inf - inf, which is nan.  If reflection_denom is +/-inf,
728     // then it "wins" and the result is +/-inf.
729     Value is_finite =
730         rewriter.create<IsFiniteOp>(loc, tensor_bool_type, reflection_denom);
731     Value neg_reflection_denom = rewriter.create<NegOp>(loc, reflection_denom);
732     Value log_pi_minus_reflection_denom =
733         rewriter.create<SubOp>(loc, log_pi, reflection_denom);
734     Value reflection_if_finite =
735         rewriter.create<SubOp>(loc, log_pi_minus_reflection_denom, log_y);
736     Value reflection = rewriter.create<SelectV2Op>(
737         loc, is_finite, reflection_if_finite, neg_reflection_denom);
738 
739     Value result =
740         rewriter.create<SelectV2Op>(loc, need_to_reflect, reflection, log_y);
741 
742     // lgamma(+/-inf) = +inf.
743     Value is_inf = rewriter.create<IsInfOp>(loc, tensor_bool_type, input);
744     result = rewriter.create<SelectV2Op>(loc, is_inf, infinity, result);
745 
746     if (needs_cast) {
747       result = rewriter.create<CastOp>(loc, original_tensor_type, result);
748     }
749 
750     rewriter.replaceOp(op, result);
751     return success();
752   }
753 };
754 
755 // Lowers Pack op to ConcatV2 op after changing shape of the inputs with
756 // ExpandDims op.
757 //
758 // Sample result with 2 inputs to pack:
759 //
760 //   %axis = "tf.Const"() {value = dense<1> : tensor<i64>}
761 //   %inp0 = "tf.ExpandDims"(%operand0, %axis): tensor<2xf32> -> tensor<2x1xf32>
762 //   %inp1 = "tf.ExpandDims"(%operand1, %axis): tensor<2xf32> -> tensor<2x1xf32>
763 //   %result = "tf.ConcatV2"(%inp0, %inp1, %axis) { N = 2 : i64 }:
764 //     (tensor<2x1xf32>, tensor<2x1xf32>, tensor<i64>) -> tensor<2x2xf32>
765 class LowerPackOp : public RewritePattern {
766  public:
767   explicit LowerPackOp(MLIRContext *context)
768       : RewritePattern(
769             PackOp::getOperationName(), 1, context,
770             {ConstOp::getOperationName(), ConcatV2Op::getOperationName(),
771              ExpandDimsOp::getOperationName()}) {}
772 
773   LogicalResult matchAndRewrite(Operation *src_op,
774                                 PatternRewriter &rewriter) const override {
775     auto op = cast<PackOp>(src_op);
776 
777     Location loc = op.getLoc();
778     auto axis_value = rewriter.create<ConstOp>(
779         loc,
780         DenseElementsAttr::get(
781             RankedTensorType::get({}, rewriter.getIntegerType(64)), op.axis()));
782     int64_t axis = op.axis();
783 
784     Type prev_input_ty, inferred_ty;
785     SmallVector<Value, 4> expanded_inputs;
786     expanded_inputs.reserve(op.N());
787     for (Value input : op.values()) {
788       // If input type is different than the previous input type, infer the
789       // output type. Otherwise, use the already inferred output type from the
790       // previous iteration.
791       Type input_ty = input.getType();
792       if (input_ty != prev_input_ty) {
793         inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
794         prev_input_ty = input_ty;
795       }
796       expanded_inputs.push_back(
797           rewriter.create<ExpandDimsOp>(loc, inferred_ty, input, axis_value));
798     }
799 
800     rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), expanded_inputs,
801                                             axis_value);
802     return success();
803   }
804 };
805 
806 // Lowers SpaceToBatchND by reducing to reshape(transpose(reshape(pad(input)))).
807 //
808 // Before rewrite:
809 //   output = SpaceToBatchND(input, block_shape, paddings)
810 // Let:
811 //   [batch] + spatial_shape + remaining_shape = input.shape
812 //   M = spatial_shape.rank
813 // After rewrite:
814 //   padded = zero-pad input with paddings
815 //     The spatial_shape component of input.shape is padded with paddings[*, 0]
816 //     before each dimension and paddings[*, 1] after each dimension.
817 //   reshaped = reshape padded to:
818 //     [batch]
819 //     + [padded.shape[1]/block_shape[0], block_shape[0], ...,
820 //        padded.shape[M]/block_shape[M-1], block_shape[M-1]]
821 //     + remaining_shape
822 //   permuted = transpose reshaped to:
823 //     block_shape
824 //     + [batch]
825 //     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
826 //     + remaining_shape
827 //   result = reshape permuted to:
828 //     [batch * product(block_shape)]
829 //     + [padded.shape[1]/block_shape[0], ..., padded.shape[M]/block_shape[M-1]]
830 //     + remaining_shape
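//
// For example, with input shape [1, 4, 4, 1], block_shape = [2, 2] and zero
// paddings, reshaped has shape [1, 2, 2, 2, 2, 1], permuted has shape
// [2, 2, 1, 2, 2, 1], and the result has shape [4, 2, 2, 1].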
831 class LowerSpaceToBatchNDOp : public RewritePattern {
832  public:
833   explicit LowerSpaceToBatchNDOp(MLIRContext *context)
834       : RewritePattern(SpaceToBatchNDOp::getOperationName(), 1, context,
835                        {
836                            CastOp::getOperationName(),
837                            ConstOp::getOperationName(),
838                            ConcatV2Op::getOperationName(),
839                            AddV2Op::getOperationName(),
840                            PadOp::getOperationName(),
841                            SplitOp::getOperationName(),
842                            UnpackOp::getOperationName(),
843                            DivOp::getOperationName(),
844                            MulOp::getOperationName(),
845                            ReshapeOp::getOperationName(),
846                            TransposeOp::getOperationName(),
847                        }) {}
848 
849   LogicalResult matchAndRewrite(Operation *src_op,
850                                 PatternRewriter &rewriter) const override {
851     auto op = cast<SpaceToBatchNDOp>(src_op);
852 
853     Location loc = op.getLoc();
854     auto input_type = op.input().getType().cast<TensorType>();
855     auto element_type = input_type.getElementType();
856     if (!input_type.hasStaticShape()) {
857       return failure();
858     }
859     ArrayRef<int64_t> input_shape = input_type.getShape();
860     auto block_shape_type = op.block_shape().getType().cast<TensorType>();
861     if (!block_shape_type.hasStaticShape()) {
862       return failure();
863     }
864     auto paddings_type = op.paddings().getType().cast<ShapedType>();
865     if (!paddings_type.hasRank()) {
866       return failure();
867     }
868 
869     int64_t input_rank = input_type.getRank();
870     int64_t block_rank = block_shape_type.getNumElements();
871     int64_t remaining_rank = input_rank - 1 - block_rank;
872     if (remaining_rank < 0) {
873       // TODO(b/157475606): Move this check to ::Verify
874       return failure();
875     }
876 
877     auto block_shape_i64_type = RankedTensorType::get(
878         block_shape_type.getShape(), rewriter.getIntegerType(64));
879     auto block_shape_i64 =
880         rewriter.create<CastOp>(loc, block_shape_i64_type, op.block_shape());
881 
882     auto paddings_i64_type = RankedTensorType::get(paddings_type.getShape(),
883                                                    rewriter.getIntegerType(64));
884     auto paddings_i64 =
885         rewriter.create<CastOp>(loc, paddings_i64_type, op.paddings());
886 
887     auto pad00 = rewriter.create<ConstOp>(
888         loc, DenseElementsAttr::get<int64_t>(
889                  RankedTensorType::get({1, 2}, rewriter.getIntegerType(64)),
890                  {0, 0}));
891     SmallVector<Value, 4> full_paddings_list{pad00, paddings_i64};
892     full_paddings_list.append(remaining_rank, pad00);
893     auto full_paddings_type =
894         RankedTensorType::get({input_rank, 2}, rewriter.getIntegerType(64));
895     auto zero_i64 = rewriter.create<ConstOp>(
896         loc, GetScalarOfType(rewriter.getIntegerType(64), 0));
897     // Extends paddings to all dimensions of input by adding 0s to non-block
898     // dimensions.
899     auto full_paddings = rewriter.create<ConcatV2Op>(
900         loc, full_paddings_type, full_paddings_list, zero_i64);
901 
902     // Compute the result type here instead of using shape inference because the
903     // full_paddings won't be available as a constant for shape inference.
904     ElementsAttr block_shape;
905     ElementsAttr paddings;
906     llvm::SmallVector<int64_t, 4> block_shape_ints;
907     auto padded_shape = llvm::to_vector<4>(input_shape);
908     if (matchPattern(op.block_shape(), m_Constant(&block_shape)) &&
909         matchPattern(op.paddings(), m_Constant(&paddings))) {
910       for (uint64_t i = 0; i < block_rank; i++) {
911         int64_t paddings_sum =
912             paddings.getValues<APInt>()[{i, 0}].getSExtValue() +
913             paddings.getValues<APInt>()[{i, 1}].getSExtValue();
914         int64_t block_shape_i =
915             block_shape.getValues<APInt>()[i].getSExtValue();
916         padded_shape[i + 1] = (paddings_sum + input_shape[i + 1]);
917         block_shape_ints.push_back(block_shape_i);
918       }
919     } else {
920       for (int i = 0; i < block_rank; i++) {
921         padded_shape[i + 1] = ShapedType::kDynamicSize;
922       }
923       block_shape_ints.resize(block_shape_type.getNumElements(), -1);
924     }
925 
926     auto padded_type = RankedTensorType::get(padded_shape, element_type);
927     // padded = pad(input, full_paddings)
928     auto padded =
929         rewriter.create<PadOp>(loc, padded_type, op.input(), full_paddings);
930 
931     auto paddings_sum_type =
932         RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
933     // paddings_sum = paddings[*,0] + paddings[*,1]
934     auto paddings_split = rewriter.create<UnpackOp>(
935         loc, TypeRange({paddings_sum_type, paddings_sum_type}), full_paddings,
936         rewriter.getI64IntegerAttr(1));
937     auto paddings_sum = rewriter.create<AddV2Op>(
938         loc, paddings_split.getResult(0), paddings_split.getResult(1));
939 
940     auto input_shape_tensor = rewriter.create<ConstOp>(
941         loc,
942         DenseElementsAttr::get(
943             RankedTensorType::get({input_rank}, rewriter.getIntegerType(64)),
944             input_shape));
945 
946     // padded_shape_tensor is the shape of padded.
947     auto padded_shape_tensor =
948         rewriter.create<AddV2Op>(loc, paddings_sum, input_shape_tensor);
949 
950     auto zero_i32 = rewriter.create<ConstOp>(
951         loc, GetScalarOfType(rewriter.getIntegerType(32), 0));
952     SmallVector<Type, 4> padded_shape_splits_types(
953         input_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
954     SmallVector<Value, 4> padded_shape_splits(
955         rewriter
956             .create<SplitOp>(loc, padded_shape_splits_types, zero_i32,
957                              padded_shape_tensor)
958             .output());
959 
960     SmallVector<Type, 4> block_shape_splits_types(
961         block_rank, RankedTensorType::get({1}, rewriter.getIntegerType(64)));
962     SmallVector<Value, 4> block_shape_splits(
963         rewriter
964             .create<SplitOp>(loc, block_shape_splits_types, zero_i32,
965                              block_shape_i64)
966             .output());
967 
968     SmallVector<int64_t, 4> outer_shape_ints;
969     SmallVector<Value, 4> outer_shape_vals;
970     for (int64_t i = 0; i < block_rank; ++i) {
971       // TODO(b/157475606): Insert tf.Assert that the following division has
972       // remainder 0.
973       outer_shape_vals.push_back(rewriter.create<DivOp>(
974           loc, padded_shape_splits[1 + i], block_shape_splits[i]));
975 
976       auto padded_shape_i = padded_shape[1 + i];
977       auto block_shape_ints_i = block_shape_ints[i];
978 
979       // Compute the outer_shape constant values to infer the reshape.
980       if (padded_shape_i == -1 || block_shape_ints_i == -1) {
981         outer_shape_ints.push_back(-1);
982       } else {
983         outer_shape_ints.push_back(padded_shape_i / block_shape_ints_i);
984       }
985     }
986 
987     SmallVector<Value, 6> reshaped_shape_vals{padded_shape_splits[0]};
988     SmallVector<int64_t, 6> reshaped_shape_ints{padded_shape[0]};
989     for (int64_t i = 0; i < block_rank; ++i) {
990       reshaped_shape_vals.push_back(outer_shape_vals[i]);
991       reshaped_shape_vals.push_back(block_shape_splits[i]);
992 
993       reshaped_shape_ints.push_back(outer_shape_ints[i]);
994       reshaped_shape_ints.push_back(block_shape_ints[i]);
995     }
996     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
997       reshaped_shape_vals.push_back(padded_shape_splits[i]);
998       reshaped_shape_ints.push_back(padded_shape[i]);
999     }
1000     auto reshaped_shape = ValuesToRank1(
1001         rewriter, loc, rewriter.getIntegerType(64), reshaped_shape_vals);
1002 
1003     auto reshaped = rewriter.create<ReshapeOp>(
1004         loc, RankedTensorType::get(reshaped_shape_ints, element_type), padded,
1005         reshaped_shape);
1006 
1007     SmallVector<int64_t, 6> permutation_vals;
1008     for (int64_t i = 0; i < block_rank; ++i) {
1009       permutation_vals.push_back(2 + 2 * i);
1010     }
1011     permutation_vals.push_back(0);
1012     for (int64_t i = 0; i < block_rank; ++i) {
1013       permutation_vals.push_back(1 + 2 * i);
1014     }
1015     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1016       permutation_vals.push_back(block_rank + i);
1017     }
1018     auto permutation = rewriter.create<ConstOp>(
1019         loc, GetI64ElementsAttr(permutation_vals, &rewriter));
1020 
1021     auto permuted = rewriter.create<TransposeOp>(loc, reshaped, permutation);
1022     auto output_batch = padded_shape_splits[0];
1023     for (int64_t i = 0; i < block_rank; ++i) {
1024       output_batch =
1025           rewriter.create<MulOp>(loc, output_batch, block_shape_splits[i]);
1026     }
1027     SmallVector<Value, 4> output_shape_vals{output_batch};
1028     for (int64_t i = 0; i < block_rank; ++i) {
1029       output_shape_vals.push_back(outer_shape_vals[i]);
1030     }
1031     for (int64_t i = 1 + block_rank; i < input_rank; ++i) {
1032       output_shape_vals.push_back(padded_shape_splits[i]);
1033     }
1034     auto output_shape = ValuesToRank1(
1035         rewriter, loc, rewriter.getIntegerType(64), output_shape_vals);
1036 
1037     // Sometimes the result type is more specific than what the reshape builder
1038     // can infer.
1039     auto result_type = op.getResult().getType();
1040     rewriter.replaceOpWithNewOp<ReshapeOp>(op, result_type, permuted,
1041                                            output_shape);
1042 
1043     return success();
1044   }
1045 };
1046 
1047 class LowerBatchToSpaceND : public RewritePattern {
1048  public:
1049   explicit LowerBatchToSpaceND(MLIRContext *context)
1050       : RewritePattern(BatchToSpaceNDOp::getOperationName(), 1, context,
1051                        {
1052                            ConstOp::getOperationName(),
1053                            ReshapeOp::getOperationName(),
1054                            SliceOp::getOperationName(),
1055                            TransposeOp::getOperationName(),
1056                        }) {}
1057 
1058   LogicalResult matchAndRewrite(Operation *src_op,
1059                                 PatternRewriter &rewriter) const override {
1060     auto op = cast<BatchToSpaceNDOp>(src_op);
1061     auto input = op.input();
1062     auto input_ty = input.getType().cast<ShapedType>();
1063     auto element_ty = input_ty.getElementType();
1064     if (!input_ty.hasStaticShape()) {
1065       return failure();
1066     }
1067 
1068     const int input_rank = input_ty.getRank();
1069     auto input_shape = input_ty.getShape();
1070 
1071     DenseIntElementsAttr block_shape;
1072     DenseIntElementsAttr crops;
1073     if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) ||
1074         !matchPattern(op.crops(), m_Constant(&crops))) {
1075       return failure();
1076     }
1077 
1078     auto block_shape_ty = block_shape.getType();
1079     if (!block_shape_ty.hasRank() || block_shape_ty.getRank() != 1) {
1080       return failure();
1081     }
1082 
1083     const int block_rank = block_shape_ty.getShape().front();
1084     auto remainder_shape = input_shape.drop_front(1 + block_rank);
1085 
1086     const int64_t batch_size = input_shape[0];
1087 
1088     // Compute the product of the block_shape values.
1089     int64_t block_num_elems = 1;
1090 
1091     for (auto val : block_shape.getValues<APInt>()) {
1092       block_num_elems *= val.getSExtValue();
1093     }
1094 
1095     if (block_num_elems <= 0) {
1096       op.emitOpError()
1097           << "The product of the block dimensions must be positive";
1098       return failure();
1099     }
1100 
1101     // 1. Reshape `input` to `reshaped` of shape:
1102     //      [block_shape[0], ..., block_shape[M-1],
1103     //       batch / prod(block_shape),
1104     //       input_shape[1], ..., input_shape[N-1]]
1105     SmallVector<int64_t> reshaped_shape;
1106     reshaped_shape.reserve(block_shape.size());
1107     for (auto val : block_shape) {
1108       reshaped_shape.push_back(val.getSExtValue());
1109     }
1110     reshaped_shape.resize(input_rank + block_rank);
1111 
1112     reshaped_shape[block_rank] = batch_size / block_num_elems;
1113     std::copy(input_shape.begin() + 1, input_shape.end(),
1114               reshaped_shape.begin() + block_rank + 1);
1115 
1116     auto reshaped = rewriter.create<TF::ReshapeOp>(
1117         op.getLoc(), RankedTensorType::get(reshaped_shape, element_ty), input,
1118         rewriter.create<ConstOp>(op.getLoc(),
1119                                  rewriter.getI64TensorAttr(reshaped_shape)));
1120 
1121     // 2. Permute dimensions of `reshaped` to produce `permuted` of shape
1122     //      [batch / prod(block_shape),
1123     //
1124     //       input_shape[1], block_shape[0],
1125     //       ...,
1126     //       input_shape[M], block_shape[M-1],
1127     //
1128     //       input_shape[M+1], ..., input_shape[N-1]]
1129     SmallVector<int64_t> permutation(reshaped_shape.size());
1130     permutation[0] = block_rank;
1131     for (int i = 0; i < block_rank; ++i) {
1132       permutation[1 + 2 * i] = block_rank + 1 + i;
1133       permutation[1 + 2 * i + 1] = i;
1134     }
1135     std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1136               1 + block_rank * 2);
1137 
1138     SmallVector<int64_t> transpose_shape(permutation.size());
1139     for (auto it : llvm::enumerate(permutation)) {
1140       transpose_shape[it.index()] = reshaped_shape[it.value()];
1141     }
1142 
1143     auto permuted = rewriter.create<TF::TransposeOp>(
1144         op.getLoc(), RankedTensorType::get(transpose_shape, element_ty),
1145         reshaped,
1146         rewriter.create<ConstOp>(op.getLoc(),
1147                                  rewriter.getI64TensorAttr(permutation)));
1148 
1149     // 3. Reshape `permuted` to produce `reshaped_permuted` of shape
1150     //      [batch / prod(block_shape),
1151     //
1152     //       input_shape[1] * block_shape[0],
1153     //       ...,
1154     //       input_shape[M] * block_shape[M-1],
1155     //
1156     //       input_shape[M+1],
1157     //       ...,
1158     //       input_shape[N-1]]
1159     SmallVector<int64_t> reshaped_permuted_shape(input_rank);
1160     auto block_shape_values =
1161         llvm::to_vector<4>(block_shape.getValues<APInt>());
1162     reshaped_permuted_shape[0] = batch_size / block_num_elems;
1163     for (int i = 0; i < block_rank; ++i) {
1164       reshaped_permuted_shape[1 + i] =
1165           block_shape_values[i].getSExtValue() * input_shape[1 + i];
1166     }
1167     std::copy(remainder_shape.begin(), remainder_shape.end(),
1168               reshaped_permuted_shape.begin() + 1 + block_rank);
1169 
1170     auto reshaped_permuted = rewriter.create<TF::ReshapeOp>(
1171         op.getLoc(), RankedTensorType::get(reshaped_permuted_shape, element_ty),
1172         permuted,
1173         rewriter.create<ConstOp>(
1174             op.getLoc(), rewriter.getI64TensorAttr(reshaped_permuted_shape)));
1175 
1176     // 4. Crop the start and end of dimensions `[1, ..., M]` of
1177     //    `reshaped_permuted` according to `crops` to produce the output of
1178     //    shape:
1179     //      [batch / prod(block_shape),
1180     //
1181     //       input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
1182     //       ...,
1183     //       input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
1184     //
1185     //       input_shape[M+1], ..., input_shape[N-1]]
1186     SmallVector<int64_t> start_indices(input_rank, 0);
1187     SmallVector<int64_t> slice_sizes = reshaped_permuted_shape;
1188     SmallVector<int64_t> strides(input_rank, 1);
1189     auto crop_values = llvm::to_vector<4>(crops.getValues<APInt>());
1190     for (int i = 0; i < block_rank; ++i) {
1191       int64_t crop_start = crop_values[i * 2].getSExtValue();
1192       int64_t crop_end = crop_values[i * 2 + 1].getSExtValue();
1193 
1194       if (crop_start < 0 || crop_end < 0) {
1195         op.emitOpError() << "Crops must be non-negative";
1196         return failure();
1197       }
1198 
1199       start_indices[i + 1] = crop_start;
1200       slice_sizes[i + 1] -= crop_start + crop_end;
1201 
1202       if (slice_sizes[i + 1] < 0) {
1203         op.emitOpError() << "Cropped size must be non-negative: start: "
1204                          << crop_start << " end: " << crop_end << " size "
1205                          << reshaped_permuted_shape[1 + i];
        return failure();
1206       }
1207     }
1208 
1209     rewriter.replaceOpWithNewOp<TF::SliceOp>(
1210         op, RankedTensorType::get(slice_sizes, element_ty), reshaped_permuted,
1211         rewriter.create<ConstOp>(op.getLoc(),
1212                                  rewriter.getI64TensorAttr(start_indices)),
1213         rewriter.create<ConstOp>(op.getLoc(),
1214                                  rewriter.getI64TensorAttr(slice_sizes)));
1215     return success();
1216   }
1217 };
1218 
1219 // Lowers `SparseMatMulOp` to `MatMulOp`, ignoring the sparseness hints,
1220 // since we currently don't have an implementation that can use this
1221 // information. Adds appropriate casts where necessary to align element types
1222 // of operands and result for `MatMulOp`.
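// Schematically (transpose attributes and sparsity hints elided), a bf16 x f32
// product such as
//   %p = "tf.SparseMatMul"(%a, %b)
//          : (tensor<2x3xbf16>, tensor<3x4xf32>) -> tensor<2x4xf32>
// is rewritten to
//   %a_f32 = "tf.Cast"(%a) : (tensor<2x3xbf16>) -> tensor<2x3xf32>
//   %p = "tf.MatMul"(%a_f32, %b)
//          : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>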
1223 class LowerSparseMatMulOp : public RewritePattern {
1224  public:
1225   explicit LowerSparseMatMulOp(MLIRContext *context)
1226       : RewritePattern(
1227             SparseMatMulOp::getOperationName(), 1, context,
1228             {CastOp::getOperationName(), MatMulOp::getOperationName()}) {}
1229 
1230   LogicalResult matchAndRewrite(Operation *src_op,
1231                                 PatternRewriter &rewriter) const override {
1232     auto op = cast<SparseMatMulOp>(src_op);
1233 
1234     // Result type must be f32 for applying the pattern (currently this is
1235     // required by the op anyway but this might change).
1236     if (!op.product().getType().cast<TensorType>().getElementType().isF32()) {
1237       return failure();
1238     }
1239     MLIRContext *context = rewriter.getContext();
1240     llvm::SmallVector<Value, 2> operands{op.a(), op.b()};
1241     for (Value &operand : operands) {
1242       TensorType tensor_type = operand.getType().cast<TensorType>();
1243       Type element_type = tensor_type.getElementType();
1244       if (element_type.isF32()) continue;
1245       // Element type can either be f32 or bf16 for `SparseMatMulOp` so it
1246       // must be bf16 here.
1247       assert(element_type.isBF16());
1248       Type tensor_type_f32;
1249       if (tensor_type.hasRank()) {
1250         tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(),
1251                                                 FloatType::getF32(context));
1252       } else {
1253         tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context));
1254       }
1255       // Add cast to f32 to conform with element type of result.
1256       operand = rewriter.create<CastOp>(op.getLoc(), tensor_type_f32, operand);
1257     }
1258     Value result = rewriter.create<MatMulOp>(
1259         op.getLoc(), op.product().getType(), operands[0], operands[1],
1260         op.transpose_a(), op.transpose_b());
1261 
1262     rewriter.replaceOp(op, {result});
1263     return success();
1264   }
1265 };
1266 
1267 // Lowers the _UnaryOpsComposition op to the series of original TensorFlow ops
1268 // that were fused together.
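// For example, if op_names is ["Abs", "Sqrt"], the composition applied to %x is
// replaced by %0 = "tf.Abs"(%x) followed by %1 = "tf.Sqrt"(%0), and %1 takes
// the place of the original result.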
1269 class Lower_UnaryOpsComposition
1270     : public OpRewritePattern<_UnaryOpsCompositionOp> {
1271  public:
1272   using OpRewritePattern<_UnaryOpsCompositionOp>::OpRewritePattern;
1273 
1274   LogicalResult matchAndRewrite(_UnaryOpsCompositionOp op,
1275                                 PatternRewriter &rewriter) const override {
1276     Value result = op.x();
1277     for (StringRef op_name : op.op_names().getAsValueRange<StringAttr>()) {
1278       std::string full_name = "tf." + op_name.str();
1279       // All ops in the sequence have the same result type as the original
1280       // op's result type.
1281       OperationState state(op.getLoc(), full_name, /*operands=*/{result},
1282                            /*types=*/{op.getType()}, /*attributes=*/{});
1283       Operation *new_op = rewriter.create(state);
1284       result = new_op->getResult(0);
1285     }
1286     rewriter.replaceOp(op, {result});
1287     return success();
1288   }
1289 };
1290 
1291 // Lowers ResizeNearestNeighbor to an index computation followed by a gather
1292 // along the combined spatial dimensions. Indices generated along the height
1293 // and width axes could be used to gather separately along the H and W
1294 // dimensions of the input image array; to reduce this to a single gather,
1295 // those per-axis indices are combined so that one gather can be performed
1296 // along the flattened spatial dimension.
1297 //
1298 // Images must have shape [b, h, w, c] and `size` must be a rank-1, length-2
1299 // tensor containing the output height and width. This lowering should also
1300 // work when the images array has dynamic dimensions.
1301 //
1302 // For example, scaling an image of shape [1, 3, 3, 1] to size [2, 2] with
1303 // unaligned corners generates a [0, 1] lookup along both the x and y
1304 // directions. Combined into the 1-D spatial index, the values are [0, 1, 3, 4],
1305 // which gather from the reshaped image tensor of shape [1, 9, 1]; the result
1306 // is then reshaped to the final [1, 2, 2, 1].
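// In general, the source row chosen for output row o is effectively
// int(o * in_y / out_y) (and analogously for columns); the row indices are then
// multiplied by the input width and added to the column indices to form the
// flattened spatial index used by the gather.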
1307 class LowerResizeNearestNeighbor : public RewritePattern {
1308  public:
1309   explicit LowerResizeNearestNeighbor(MLIRContext *context)
1310       : RewritePattern(ResizeNearestNeighborOp::getOperationName(), 1, context,
1311                        {
1312                            BroadcastToOp::getOperationName(),
1313                            ConstOp::getOperationName(),
1314                            DivOp::getOperationName(),
1315                            PackOp::getOperationName(),
1316                            RangeOp::getOperationName(),
1317                            ReshapeOp::getOperationName(),
1318                            ShapeOp::getOperationName(),
1319                            SplitOp::getOperationName(),
1320                            TransposeOp::getOperationName(),
1321                        }) {}
1322 
1323   LogicalResult matchAndRewrite(Operation *src_op,
1324                                 PatternRewriter &rewriter) const override {
1325     auto op = cast<ResizeNearestNeighborOp>(src_op);
1326     auto loc = op.getLoc();
1327     auto result_ty = op.getType().cast<ShapedType>();
1328 
1329     auto input = op.images();
1330     auto input_ty = input.getType().cast<ShapedType>();
1331     auto input_element_ty = input_ty.getElementType();
1332     auto out_size = op.size();
1333     auto out_size_ty = out_size.getType().cast<ShapedType>();
1334     auto out_size_element_ty = out_size_ty.getElementType();
1335 
1336     // Input should be rank 4.
1337     if (!input_ty.hasRank() || input_ty.getRank() != 4) {
1338       return failure();
1339     }
1340 
1341     // Check that out_size is rank-1, length-2. Otherwise the size is not legal.
1342     if (!out_size_ty.hasRank() || out_size_ty.getRank() != 1 ||
1343         out_size_ty.getShape()[0] != 2) {
1344       return failure();
1345     }
1346 
1347     // Extract the output width / height dim size.
1348     int out_height_constant = -1;
1349     int out_width_constant = -1;
1350     DenseIntElementsAttr out_size_cst;
1351     if (matchPattern(out_size, m_Constant(&out_size_cst))) {
1352       llvm::SmallVector<int64_t, 2> cst_size;
1353       for (auto val : out_size_cst.getValues<APInt>()) {
1354         cst_size.push_back(val.getSExtValue());
1355       }
1356 
1357       out_height_constant = cst_size[0];
1358       out_width_constant = cst_size[1];
1359 
1360       if (out_height_constant < 0 || out_width_constant < 0) return failure();
1361     }
1362 
1363     int out_spatial_cst = out_height_constant < 0 || out_width_constant < 0
1364                               ? -1
1365                               : out_height_constant * out_width_constant;
1366 
1367     // Input rank should be 4 (already verified above). Might be able to drop
1368     // this requirement entirely as it's an input requirement.
1369     if (!input_ty.hasRank() || input_ty.getRank() != 4) {
1370       return failure();
1371     }
1372 
1373     int batch_cst = input_ty.getShape()[0];
1374     int channels_cst = input_ty.getShape()[3];
1375 
1376     int in_y_cst = input_ty.getShape()[1];
1377     int in_x_cst = input_ty.getShape()[2];
1378     int in_spatial_cst =
1379         in_y_cst < 0 || in_x_cst < 0 ? -1 : in_y_cst * in_x_cst;
1380 
1381     // TODO(suderman): Add support for these optional parameters.
1382     if (op.align_corners() || op.half_pixel_centers()) {
1383       return failure();
1384     }
1385 
1386     auto one =
1387         rewriter.create<ConstOp>(loc, GetScalarOfType(out_size_element_ty, 1));
1388 
1389     // Extract the image shape.
1390     Value input_shape = rewriter.create<ShapeOp>(
1391         loc, RankedTensorType::get({4}, rewriter.getI64Type()), input);
1392     input_shape = rewriter.create<CastOp>(
1393         loc, RankedTensorType::get({4}, out_size_element_ty), input_shape);
1394 
1395     auto scalar_dim_ty = RankedTensorType::get({}, out_size_element_ty);
1396     auto split_image_shape = rewriter.create<UnpackOp>(
1397         loc,
1398         TypeRange({scalar_dim_ty, scalar_dim_ty, scalar_dim_ty, scalar_dim_ty}),
1399         input_shape);
1400 
1401     // Extract the separate components from the input shape.
1402     auto batch = split_image_shape.getResult(0);
1403     auto in_y = split_image_shape.getResult(1);
1404     auto in_x = split_image_shape.getResult(2);
1405     auto channels = split_image_shape.getResult(3);
1406 
1407     auto in_count = rewriter.create<MulOp>(
1408         loc, RankedTensorType::get({}, out_size_element_ty), in_y, in_x);
1409 
1410     // Unpack and separate the out width/height.
1411     auto split_out_size = rewriter.create<UnpackOp>(
1412         loc, TypeRange({scalar_dim_ty, scalar_dim_ty}), out_size);
1413 
1414     auto out_y = split_out_size.getResult(0);
1415     auto out_x = split_out_size.getResult(1);
1416 
1417     auto out_count = rewriter.create<MulOp>(
1418         loc, RankedTensorType::get({}, out_size_element_ty), out_y, out_x);
1419 
1420     // Generate what the final output shape will look like.
1421     auto out_shape = rewriter.create<PackOp>(
1422         loc, RankedTensorType::get({4}, out_size_element_ty),
1423         ValueRange({batch, out_y, out_x, channels}));
1424 
1425     // Compute the indices along the vertical dimension.
1426     auto in_y_f32 = rewriter.create<CastOp>(
1427         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y);
1428     auto out_y_f32 = rewriter.create<CastOp>(
1429         loc, RankedTensorType::get({}, rewriter.getF32Type()), out_y);
1430 
1431     Value y_scale = rewriter.create<DivOp>(
1432         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_y_f32,
1433         out_y_f32);
1434 
1435     Value zero_f32 = rewriter.create<ConstOp>(
1436         loc, GetScalarOfType(rewriter.getF32Type(), 0.0));
1437     Value one_f32 = rewriter.create<ConstOp>(
1438         loc, GetScalarOfType(rewriter.getF32Type(), 1.0));
1439 
1440     Value y_range = rewriter.create<RangeOp>(
1441         loc,
1442         RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1443         zero_f32, out_y_f32, one_f32);
1444 
1445     y_range = rewriter.create<MulOp>(
1446         loc,
1447         RankedTensorType::get({out_height_constant}, rewriter.getF32Type()),
1448         y_range, y_scale);
1449 
1450     y_range = rewriter.create<CastOp>(
1451         loc, RankedTensorType::get({out_height_constant}, out_size_element_ty),
1452         y_range);
1453 
1454     y_range = rewriter.create<ReshapeOp>(
1455         loc,
1456         RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1457         y_range,
1458         rewriter.create<PackOp>(loc,
1459                                 RankedTensorType::get({2}, out_size_element_ty),
1460                                 ValueRange({out_y, one})));
1461 
1462     Value y_indices = rewriter.create<MulOp>(
1463         loc,
1464         RankedTensorType::get({out_height_constant, 1}, out_size_element_ty),
1465         y_range, in_x);
1466 
1467     // Compute the indices for the nearest neighbour lookup across the width
1468     // dim.
1469     auto in_x_f32 = rewriter.create<CastOp>(
1470         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x);
1471     auto out_x_f32 = rewriter.create<CastOp>(
1472         loc, RankedTensorType::get({}, rewriter.getF32Type()), out_x);
1473 
1474     Value x_scale = rewriter.create<DivOp>(
1475         loc, RankedTensorType::get({}, rewriter.getF32Type()), in_x_f32,
1476         out_x_f32);
1477 
1478     Value x_range = rewriter.create<RangeOp>(
1479         loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1480         zero_f32, out_x_f32, one_f32);
1481 
1482     x_range = rewriter.create<MulOp>(
1483         loc, RankedTensorType::get({out_width_constant}, rewriter.getF32Type()),
1484         x_range, x_scale);
1485 
1486     x_range = rewriter.create<CastOp>(
1487         loc, RankedTensorType::get({out_width_constant}, out_size_element_ty),
1488         x_range);
1489 
1490     Value x_indices = rewriter.create<ReshapeOp>(
1491         loc,
1492         RankedTensorType::get({1, out_width_constant}, out_size_element_ty),
1493         x_range,
1494         rewriter.create<PackOp>(loc,
1495                                 RankedTensorType::get({2}, out_size_element_ty),
1496                                 ValueRange({one, out_x})));
1497 
1498     // Generate the combined index array, reshape to be 1-D.
1499     Value indices = rewriter.create<AddV2Op>(
1500         loc,
1501         RankedTensorType::get({out_height_constant, out_width_constant},
1502                               out_size_element_ty),
1503         y_indices, x_indices);
1504 
1505     indices = rewriter.create<ReshapeOp>(
1506         loc, RankedTensorType::get({out_spatial_cst}, out_size_element_ty),
1507         indices,
1508         rewriter.create<ReshapeOp>(
1509             loc, RankedTensorType::get({1}, out_size_element_ty), out_count,
1510             rewriter.create<ConstOp>(loc, rewriter.getI64TensorAttr({1}))));
1511 
1512     // Group the spatial indices and gather along that combined index.
1513     Value input_collapsed_spatial = rewriter.create<ReshapeOp>(
1514         loc,
1515         RankedTensorType::get({batch_cst, in_spatial_cst, channels_cst},
1516                               input_element_ty),
1517         input,
1518         rewriter.create<PackOp>(loc,
1519                                 RankedTensorType::get({3}, out_size_element_ty),
1520                                 ValueRange({batch, in_count, channels})));
1521 
1522     Value gathered_values = rewriter.create<GatherV2Op>(
1523         loc,
1524         RankedTensorType::get({batch_cst, out_spatial_cst, channels_cst},
1525                               input_element_ty),
1526         input_collapsed_spatial, indices, /*axis=*/one);
1527 
1528     gathered_values =
1529         rewriter.create<ReshapeOp>(loc, result_ty, gathered_values, out_shape);
1530 
1531     rewriter.replaceOp(op, gathered_values);
1532     return success();
1533   }
1534 };
1535 
1536 struct LowerRollOp : public RewritePattern {
1537   explicit LowerRollOp(MLIRContext *context)
1538       : RewritePattern(
1539             RollOp::getOperationName(), 1, context,
1540             {ConstOp::getOperationName(), SliceOp::getOperationName(),
1541              ConcatV2Op::getOperationName()}) {}
1542 
1543   LogicalResult matchAndRewrite(Operation *op,
1544                                 PatternRewriter &rewriter) const override {
1545     auto tf_roll_op = cast<RollOp>(op);
1546 
1547     auto input_ty = tf_roll_op.input().getType().dyn_cast<RankedTensorType>();
1548     if (!input_ty || !input_ty.hasStaticShape()) {
1549       return rewriter.notifyMatchFailure(
1550           op, "require the type of input to have static shapes");
1551     }
1552 
1553     DenseIntElementsAttr shift_attr;
1554     Value shift = tf_roll_op.shift();
1555     auto shift_ranked_attr_type = shift.getType().dyn_cast<RankedTensorType>();
1556     if (!shift_ranked_attr_type ||
1557         !matchPattern(shift, m_Constant(&shift_attr))) {
1558       return failure();
1559     }
1560 
1561     DenseIntElementsAttr axis_attr;
1562     Value axis = tf_roll_op.axis();
1563     auto axis_ranked_attr_type = axis.getType().dyn_cast<RankedTensorType>();
1564     if (!axis_ranked_attr_type || !matchPattern(axis, m_Constant(&axis_attr))) {
1565       return failure();
1566     }
1567 
1568     // Combine duplicate axes and make sure they are in [0, rank(input)) range.
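    // For example, shift = [1, 3] with axis = [0, 0] is equivalent to a single
    // shift of 4 along dimension 0, and a negative axis such as -1 refers to
    // the last dimension of the input.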
1569     auto input_shape = input_ty.getShape();
1570     int input_rank = input_shape.size();
1571     SmallVector<int32_t, 4> shift_map(input_rank, 0);
1572     for (int i = 0; i < axis_attr.getNumElements(); ++i) {
1573       int32_t axis_i = axis_attr.getValues<int32_t>()[i];
1574       if (axis_i < 0) axis_i += input_rank;
1575       int32_t shift_i = shift_attr.getValues<int32_t>()[i];
1576       shift_map[axis_i] += shift_i;
1577     }
1578 
1579     SmallVector<int32_t, 4> adjusted_axis;
1580     SmallVector<int32_t, 4> adjusted_shift;
1581     for (int i = 0; i < input_rank; ++i) {
1582       int32_t input_dims_i = input_shape[i];
1583       int32_t shift_i = shift_map[i] % input_dims_i;
1584       if (shift_i < 0) shift_i += input_dims_i;
1585       if (shift_i == 0) continue;
1586       adjusted_axis.push_back(i);
1587       adjusted_shift.push_back(shift_i);
1588     }
1589 
1590     // Convert rolling in each dimension to two Slice ops and one Concat op.
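    // For example, rolling a dimension of size 5 by a shift of 2 slices the
    // last two elements (offset 3, size 2) and the first three elements
    // (offset 0, size 3) and concatenates them, so [0, 1, 2, 3, 4] becomes
    // [3, 4, 0, 1, 2].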
1591     auto axis_type =
1592         RankedTensorType::get({input_rank}, rewriter.getIntegerType(64));
1593     auto create_slice_op = [&](int32_t axis_i, int32_t begin_i, int32_t size_i,
1594                                Value input) {
1595       SmallVector<int64_t, 4> begin_values(input_rank, 0);
1596       begin_values[axis_i] = begin_i;
1597       auto begin_attr = DenseIntElementsAttr::get(axis_type, begin_values);
1598       auto begin =
1599           rewriter.create<ConstOp>(op->getLoc(), axis_type, begin_attr);
1600 
1601       SmallVector<int64_t, 4> output_shape;
1602       output_shape.append(input_shape.begin(), input_shape.end());
1603       output_shape[axis_i] = size_i;
1604       auto size_attr = DenseIntElementsAttr::get(axis_type, output_shape);
1605       auto size = rewriter.create<ConstOp>(op->getLoc(), axis_type, size_attr);
1606 
1607       auto slice_op_ty =
1608           RankedTensorType::get(output_shape, input_ty.getElementType());
1609       return rewriter.create<SliceOp>(op->getLoc(), slice_op_ty, input, begin,
1610                                       size);
1611     };
1612 
1613     auto result = tf_roll_op.input();
1614     auto scalar_type =
1615         mlir::RankedTensorType::get({}, rewriter.getIntegerType(32));
1616     for (int i = 0; i < adjusted_axis.size(); ++i) {
1617       int32_t axis_i = adjusted_axis[i];
1618       int32_t shift_i = adjusted_shift[i];
1619       auto slice_op_1 = create_slice_op(axis_i, input_shape[axis_i] - shift_i,
1620                                         shift_i, result);
1621       auto slice_op_2 =
1622           create_slice_op(axis_i, 0, input_shape[axis_i] - shift_i, result);
1623 
1624       auto dim_attr = DenseIntElementsAttr::get(scalar_type, {axis_i});
1625       auto concat_dim =
1626           rewriter.create<ConstOp>(op->getLoc(), scalar_type, dim_attr);
1627       auto concat_op = rewriter.create<ConcatV2Op>(
1628           op->getLoc(), input_ty,
1629           ArrayRef<Value>({slice_op_1.output(), slice_op_2.output()}),
1630           concat_dim);
1631       result = concat_op.getResult();
1632     }
1633 
1634     rewriter.replaceOp(op, result);
1635     return success();
1636   }
1637 };
1638 
1639 // Decomposes Softmax and LogSoftmax to primitive TF ops, using the following
1640 // formulas:
1641 //
1642 //     softmax = div(exp(logits), sum(exp(logits)))
1643 //     log_softmax = sub(logits, log(sum(exp(logits))))
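//     e.g. softmax([1, 2, 3]) = exp([1, 2, 3]) / sum(exp([1, 2, 3]))
//                             ~ [0.090, 0.245, 0.665]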
1644 //
1645 // TODO(jpienaar): Evaluate benefit of templating here.
1646 template <typename OpTy, bool use_log = true>
1647 class LowerSoftmaxOp : public OpRewritePattern<OpTy> {
1648  public:
1649   using OpRewritePattern<OpTy>::OpRewritePattern;
1650 
1651   LogicalResult matchAndRewrite(OpTy op,
1652                                 PatternRewriter &rewriter) const override {
1653     Value logits = op.logits();
1654     auto loc = op.getLoc();
1655 
1656     // Note that the TensorFlow Softmax op verifies that the input rank is
1657     // greater than or equal to one so the following sequence is valid.
1658     auto reduce_dim =
1659         rewriter.create<TF::ConstOp>(loc, GetI64ElementsAttr({-1}, &rewriter));
1660 
1661     // The exponentials of the input values, and hence their sum, can be very
1662     // large here, and division by a large denominator is numerically unstable.
1663     // To improve numerical stability, subtract from each batch its maximum
1664     // element so that the largest input value is zero. It can be shown that
1665     // shifting all inputs in a batch by a common value leaves the softmax
1666     // result mathematically unchanged.
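    // Concretely, with m = max(logits) taken along the last dimension:
    //   softmax(logits)     = exp(logits - m) / sum(exp(logits - m))
    //   log_softmax(logits) = (logits - m) - log(sum(exp(logits - m)))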
1667     auto max_logits =
1668         rewriter.create<TF::MaxOp>(loc, logits, reduce_dim,
1669                                    /*keep_dims=*/rewriter.getBoolAttr(true));
1670     auto shifted_logits = rewriter.create<TF::SubOp>(loc, logits, max_logits);
1671 
1672     // Exponentiate the inputs.
1673     Value exp = rewriter.create<TF::ExpOp>(loc, shifted_logits);
1674 
1675     // Compute summation of the exponentials.
1676     Value sum =
1677         rewriter.create<TF::SumOp>(loc, exp, reduce_dim,
1678                                    /*keep_dims=*/rewriter.getBoolAttr(true));
1679 
1680     if (use_log) {
1681       Value log = rewriter.create<TF::LogOp>(loc, sum);
1682       rewriter.replaceOpWithNewOp<TF::SubOp>(op, shifted_logits, log);
1683     } else {
1684       rewriter.replaceOpWithNewOp<TF::DivOp>(op, exp, sum);
1685     }
1686     return success();
1687   }
1688 };
1689 
1690 }  // namespace
1691 
1692 void PopulateLoweringTFPatterns(MLIRContext *context,
1693                                 RewritePatternSet *patterns) {
1694   // clang-format off
1695   patterns->add<
1696       LowerAddNOp,
1697       LowerExp1mOp,
1698       ConvertFakeQuantWithMinMaxVarsOp,
1699       LowerDynamicStitchOp<DynamicStitchOp>,
1700       LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1701       LowerInvertPermutationOp,
1702       LowerLgammaOp,
1703       LowerPackOp,
1704       LowerBatchToSpaceND,
1705       LowerSpaceToBatchNDOp,
1706       LowerResizeNearestNeighbor,
1707       LowerSparseMatMulOp,
1708       Lower_UnaryOpsComposition,
1709       LowerRollOp>(context);
1710   // clang-format on
1711   populateWithGenerated(*patterns);
1712 }
1713 
1714 void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context,
1715                                          RewritePatternSet *patterns) {
1716   // clang-format off
1717   patterns->add<
1718       ConvertFakeQuantWithMinMaxVarsOp,
1719       LowerAddNOp,
1720       LowerBatchToSpaceND,
1721       LowerDynamicStitchOp<DynamicStitchOp>,
1722       LowerDynamicStitchOp<ParallelDynamicStitchOp>,
1723       LowerInvertPermutationOp,
1724       LowerPackOp,
1725       LowerResizeNearestNeighbor,
1726       LowerSoftmaxOp<TF::LogSoftmaxOp, /*use_log=*/true>,
1727       LowerSoftmaxOp<TF::SoftmaxOp, /*use_log=*/false>,
1728       LowerSpaceToBatchNDOp,
1729       LowerSparseMatMulOp,
1730       Lower_UnaryOpsComposition,
1731       LowerRollOp>(context);
1732   // clang-format on
1733 
1734   // Populate the relevant generated patterns.
1735   // clang-format off
1736   patterns->add<
1737       LowerAddOp,
1738       LowerBiasAddGradOp,
1739       LowerDivNoNanOp,
1740       LowerEmptyOp,
1741       LowerFakeQuantWithMinMaxArgs,
1742       LowerFillOp,
1743       LowerInv,
1744       LowerIsNanOp,
1745       LowerL2LossOp,
1746       LowerMulNoNanOp,
1747       LowerPadOp,
1748       LowerReciprocal,
1749       LowerRintOp,
1750       LowerRoundOpOnFloatTensor,
1751       LowerRoundOpOnIntTensor,
1752       LowerRsqrtGradOp,
1753       LowerScatterNdOp,
1754       LowerSeluOp,
1755       LowerSeluGradOp,
1756       LowerSizeOp,
1757       LowerSoftmaxCrossEntropyWithLogitsOp,
1758       LowerSparseSoftmaxCrossEntropyWithLogitsOp,
1759       LowerSqrtGradOp,
1760       LowerSquareOp,
1761       LowerSquaredDifferenceOpOnRealTensors,
1762       LowerSquaredDifferenceOpOneComplexTensors,
1763       LowerTanhGradOp,
1764       LowerTruncateDivOp,
1765       LowerXdivyOp,
1766       LowerXlog1pyOp,
1767       LowerXlogyOp>(context);
1768   // clang-format on
1769 }
1770 
1771 void PopulateLoweringQuantizedPatterns(MLIRContext *context,
1772                                        RewritePatternSet *patterns) {
1773   // clang-format off
1774   patterns->add<
1775       LowerDequantizeOp>(context);
1776   // clang-format on
1777 }
1778 
1779 }  // namespace TF
1780 }  // namespace mlir
1781