/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This transformation pass takes operations in the TensorFlow Lite dialect
// and optimizes them, producing operations in the same dialect.

#include <algorithm>
#include <climits>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <numeric>
#include <utility>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

//===----------------------------------------------------------------------===//
// The actual Optimize Pass.
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

constexpr char kRelu[] = "RELU";
constexpr char kRelu6[] = "RELU6";
constexpr char kRelu1[] = "RELU_N1_TO_1";

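// Returns true if the reduction `axis` over `sq_op` is compatible with an L2
// normalization: either the reduction targets the innermost dimension (axis
// equal to rank - 1, or -1), or it covers every dimension [0, rank).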
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
  if (axis.getNumElements() == 0) {
    return false;
  }
  if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
          *axis.getValues<int>().begin() ||
      *axis.getValues<int>().begin() == -1) {
    return true;
  }
  if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
    return false;
  }
  auto shape = sq_op.getType().cast<ShapedType>();
  SmallVector<int, 4> elems{axis.getValues<int>().begin(),
                            axis.getValues<int>().end()};
  for (int i = 0; i < shape.getRank(); ++i) {
    if (i != elems[i]) return false;
  }
  return true;
}

using ::llvm::cast;

// Optimize TFLite operations in functions.
class OptimizePass : public OptimizePassBase<OptimizePass> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass)

  OptimizePass() = default;
  OptimizePass(const OptimizePass &) {}
  explicit OptimizePass(bool enable_canonicalization) {
    this->enable_canonicalization_ = enable_canonicalization;
  }

  void runOnOperation() override;
};

// Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
  return OpTrait::util::getBroadcastedType(a, b) != Type();
}

// Returns whether the resultant type of any broadcastable operation with
// operands `a` and `b` matches `expected_output`. Returns false if `a` is not
// broadcast-compatible with `b`.
bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) {
  Type output_element_type =
      expected_output.cast<ShapedType>().getElementType();
  Type broadcasted_type =
      OpTrait::util::getBroadcastedType(a, b, output_element_type);
  return broadcasted_type != Type() && broadcasted_type == expected_output;
}

// Returns whether `type1`'s dimensions are the same as the ending dimensions
// of `type2`. This is more restrictive than broadcastability.
bool IsTailOfShape(Type type1, Type type2) {
  auto tail_type = type1.dyn_cast<ShapedType>();
  auto full_type = type2.dyn_cast<ShapedType>();
  if (!tail_type || !full_type || !tail_type.hasRank() ||
      !full_type.hasRank() || tail_type.getRank() > full_type.getRank())
    return false;
  auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
  auto i2 = full_type.getShape().rbegin();
  return std::equal(i1, e1, i2);
}
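// For example, tensor<1x2xf32> is a tail of tensor<5x1x2xf32>, but
// tensor<5x1xf32> is not.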

bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
                                      const ArrayRef<int64_t> elements_shape,
                                      bool is_depthwise) {
  // The elements tensor must be of rank 1 or 0 (scalar).
  const auto elements_rank = elements_shape.size();
  if (elements_rank != 1 && elements_rank != 0) {
    return false;
  }
  auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back();
  // If the elements depth equals 1 (i.e., a scalar or a tensor with one
  // element), the binary op can simply broadcast the elements.
  if (elements_depth == 1) {
    return true;
  }

  // In TFLite, Conv2D uses the OHWI format for its filter, and DepthwiseConv
  // uses 1HWO. For conv, check that the first filter dimension equals the
  // elements depth; for depthwise conv, check the last filter dimension.
  if (filter_shape.empty() ||
      (is_depthwise ? filter_shape.back() != elements_depth
                    : filter_shape[0] != elements_depth))
    return false;
  return true;
}
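// For example, a Conv2D filter of shape [16, 3, 3, 8] can fuse an elements
// shape of [16] (or a scalar), while a DepthwiseConv2D filter of shape
// [1, 3, 3, 16] fuses along its last dimension.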

bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
                                bool is_depthwise) {
  const auto elements = val.dyn_cast<DenseElementsAttr>();
  if (!elements) {
    return false;
  }
  const auto elements_shape = elements.getType().getShape();
  const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
  return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
                                          is_depthwise);
}

bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
                                bool is_depthwise) {
  if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
    if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
      return CanFuseConvOrDepthwiseConvShapes(
          filter_elements.getType().getShape(), elements.getType().getShape(),
          is_depthwise);
    }
  }
  return false;
}

// Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the
// values of `indices` are from 0 to n-1, the output tensor is identical to
// `params`.
bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
                                              DenseIntElementsAttr indices,
                                              Type output_type) {
  auto params_type = params.getType().dyn_cast<RankedTensorType>();
  auto indices_type = indices.getType().dyn_cast<RankedTensorType>();
  // Checks that the shape of `params` is [n, ...] and the shape of `indices`
  // is [n, 1]. A 2D `indices` with a trailing dimension of 1 indexes along
  // the first dimension of `params`; as long as the indices iterate over that
  // dimension in order, the output is identical to the input.
  if (!params_type || !indices_type || indices_type.getRank() != 2 ||
      indices_type.getDimSize(0) != params_type.getDimSize(0) ||
      indices_type.getDimSize(1) != 1)
    return false;

  // Checks that `params_type` is equal to `output_type`. If not equal, we
  // cannot replace the scatter_nd/gather_nd op with `params`.
  if (params_type != output_type) return false;

  // Checks that the values in `indices` are from 0 to n-1.
  int cur_value = 0;
  for (const auto &v : indices.getValues<APInt>()) {
    if (v.getSExtValue() != cur_value) return false;
    ++cur_value;
  }

  return true;
}
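// For example, a sketch of an identity gather_nd that this matches:
//   %indices = arith.constant dense<[[0], [1], [2]]> : tensor<3x1xi32>
//   %0 = "tfl.gather_nd"(%params, %indices)
//       : (tensor<3x4xf32>, tensor<3x1xi32>) -> tensor<3x4xf32>
// Here %0 is identical to %params and the op can be removed.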

// Returns true if we can eliminate the SliceOp. When the values of `begin` are
// all 0s and `size[i]` is equal to either -1 or `input.shape[i]`
// for each dim i, the output tensor is identical to `input`.
bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) {
  // Checks if `begin` and `size` are i32 or i64.
  auto begin_attr = begin.dyn_cast<DenseIntElementsAttr>();
  auto size_attr = size.dyn_cast<DenseIntElementsAttr>();
  if (!begin_attr || !size_attr) {
    return false;
  }

  auto begin_elem_ty = begin_attr.getType().getElementType();
  if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) {
    return false;
  }
  auto size_elem_ty = size_attr.getType().getElementType();
  if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) {
    return false;
  }

  // Checks if `input` is ranked and its rank is equal to the number of
  // elements in `begin` and `size`.
  auto input_ty = input.getType().cast<ShapedType>();
  if (!input_ty.hasRank()) {
    return false;
  }

  int64_t rank = input_ty.getRank();
  if (rank != begin_attr.getNumElements() ||
      rank != size_attr.getNumElements()) {
    return false;
  }

  // Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or
  // `input.shape[i]`.
  for (uint64_t i = 0; i < rank; ++i) {
    if (begin_attr.getValues<APInt>()[i].getSExtValue() != 0) return false;
    int64_t si = size_attr.getValues<APInt>()[i].getSExtValue();
    if (si != -1 && si != input_ty.getDimSize(i)) return false;
  }

  return true;
}
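// For example, a slice over tensor<2x4xf32> with begin = [0, 0] and
// size = [-1, 4] (or [2, 4]) returns the input unchanged and can be removed.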

// Expands Attribute 'a' to 4D with all dimensions 1 except one; which
// dimension is expanded depends on whether 'is_depthwise' is true.
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
  auto elements = a.dyn_cast<DenseElementsAttr>();
  auto shape = elements.getType().getShape();
  if (!shape.empty()) {
    // Checks that elements are essentially 1d.
    assert(elements.getNumElements() == shape.back());
  }
  std::vector<int64_t> shape_data = {1, 1, 1, 1};
  const int vector_length = elements.getNumElements();
  if (is_depthwise)
    shape_data[3] = vector_length;
  else
    shape_data[0] = vector_length;
  auto new_shape =
      RankedTensorType::get(shape_data, elements.getType().getElementType());
  return elements.reshape(new_shape);
}
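// For example, a 1-D attribute of shape {n} becomes shape {n, 1, 1, 1} for
// conv (OHWI) and shape {1, 1, 1, n} for depthwise conv (1HWO).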

ElementsAttr ExpandTo4DForConv(Attribute a) {
  return ExpandTo4DForConvImpl(a, false);
}

ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
  return ExpandTo4DForConvImpl(a, true);
}

TypeAttr RescaleQtype(Type input, Attribute factor) {
  return quant::RescaleQuantizedType(input, factor);
}

// Returns the shape of a ranked tensor as a 1-D i32 attribute.
// Precondition: output_val is a ranked tensor.
DenseElementsAttr GetShape(Value output_val) {
  auto output_type = output_val.getType().cast<RankedTensorType>();
  auto shape_vector = output_type.getShape();
  std::vector<int32_t> shape;
  shape.reserve(shape_vector.size());
  for (auto shape_object : shape_vector) {
    shape.push_back(shape_object);
  }
  return mlir::DenseElementsAttr::get(
      RankedTensorType::get(
          {static_cast<int>(shape.size())},
          mlir::IntegerType::get(output_val.getContext(), 32)),
      llvm::makeArrayRef(shape));
}
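// For example, a value of type tensor<2x3xf32> yields the attribute
// dense<[2, 3]> : tensor<2xi32>.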

static Type GetShapeStrippedType(TypeAttr type_attr) {
  auto type = type_attr.getValue();
  auto shaped_type = type.dyn_cast<ShapedType>();
  if (shaped_type) {
    return shaped_type.getElementType();
  } else {
    return type;
  }
}

// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results
// in the specified `shape` and `false` otherwise.
static bool ShapeMatchesReduceWithKeepAxes(Value input,
                                           const mlir::Attribute &axes,
                                           const mlir::Attribute &shape) {
  RankedTensorType type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!type) return false;

  DenseIntElementsAttr axes_attr =
      axes.dyn_cast_or_null<DenseIntElementsAttr>();
  DenseIntElementsAttr shape_attr =
      shape.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!axes_attr || !shape_attr) return false;

  if (shape_attr.getNumElements() != type.getRank()) return false;

  llvm::SmallSet<uint64_t, 4> axes_set;
  for (auto a : axes_attr.getValues<APInt>()) {
    axes_set.insert(a.getZExtValue());
  }

  auto type_shape = type.getShape();
  for (uint64_t i = 0; i < type.getRank(); ++i) {
    if (axes_set.contains(i)) {
      if (shape_attr.getValues<APInt>()[i] != 1) return false;
    } else {
      if (shape_attr.getValues<APInt>()[i] != type_shape[i]) return false;
    }
  }
  return true;
}
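// For example, reducing a tensor<2x3x4xf32> input over axes = [1] with
// keep_dims=true yields shape [2, 1, 4], so this returns true for that pair.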

// Returns `true` if all the `axes` dimensions of `input` are 1.
static bool AreInputDimensionsOneInAxes(Value input,
                                        const mlir::Attribute &axes) {
  RankedTensorType input_type =
      input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type) return false;
  auto type_shape = input_type.getShape();

  DenseIntElementsAttr axes_attr =
      axes.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!axes_attr) return false;

  for (auto a : axes_attr.getValues<APInt>()) {
    int64_t axis = a.getSExtValue();
    if (axis < 0) {
      axis += type_shape.size();
    }
    if (axis < 0 || axis >= type_shape.size()) {
      // `axis` is not a valid axis in input.
      return false;
    }
    if (type_shape[axis] != 1) {
      return false;
    }
  }

  return true;
}

static bool FloatValueEquals(const Attribute &attr, double value) {
  auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
  if (!fp_attr) return false;

  if (fp_attr.isSplat()) {
    return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
  }
  return llvm::all_of(fp_attr.getValues<APFloat>(), [value](const APFloat &f) {
    return f.isExactlyValue(value);
  });
}

// Returns true if the value's element type is F32.
bool IsF32Value(Value value) {
  return value.getType().cast<ShapedType>().getElementType().isF32();
}

// Returns the number of elements in attr if it is a DenseElementsAttr, 1
// otherwise, as a 0-d (scalar) int32 Attribute.
Attribute GetNumElementsOrOne(Attribute attr) {
  const auto dense_attr = attr.dyn_cast_or_null<DenseElementsAttr>();
  int32_t num_elements = dense_attr ? dense_attr.getNumElements() : 1;

  OpBuilder builder(attr.getContext());

  return DenseIntElementsAttr::get(
      RankedTensorType::get({}, builder.getI32Type()),
      {llvm::APInt(32, num_elements, true)});
}

bool HasExactlyTwoElements(Attribute attr) {
  const auto values = attr.dyn_cast_or_null<ElementsAttr>();
  if (!values) return false;
  return values.getNumElements() == 2;
}

// Returns true if attr is a DenseIntElementsAttr whose last element equals 1.
bool IsLastElementEqualsOne(Attribute attr) {
  const auto ints = attr.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!ints) return false;
  if (ints.empty()) return false;
  const auto last_element_index = ints.getNumElements() - 1;
  const auto iterator = ints.value_begin<int>();
  const int last_element = iterator[last_element_index];
  return last_element == 1;
}

// Reshapes `value` to the given `shape` with its trailing dimension (of size
// 1) dropped.
Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value,
                                  Attribute shape) {
  // This function is always guarded with IsLastElementEqualsOne(), so we can
  // cast safely here.
  const auto old_shape = shape.cast<DenseIntElementsAttr>();
  auto iterator = old_shape.value_begin<int>();
  SmallVector<int, 4> new_shape;
  SmallVector<int64_t, 4> new_shape_i64;
  for (int i = 0; i < old_shape.size() - 1; ++i) {
    new_shape.push_back(*iterator);
    new_shape_i64.push_back(*iterator);
    ++iterator;
  }
  return builder.create<ReshapeOp>(
      value.getLoc(),
      RankedTensorType::get(
          new_shape_i64, value.getType().cast<ShapedType>().getElementType()),
      value,
      builder.create<arith::ConstantOp>(
          value.getLoc(), DenseIntElementsAttr::get(
                              RankedTensorType::get({old_shape.size() - 1},
                                                    builder.getI32Type()),
                              new_shape)));
}
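// For example, given shape = [2, 3, 1], the value is reshaped to [2, 3] using
// a constant shape operand of type tensor<2xi32>.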

// Returns true if val has a static shape and the last dimension equals 1.
bool IsLastDimensionEqualOne(Value val) {
  const auto val_type = val.getType().cast<ShapedType>();
  if (!val_type.hasStaticShape()) return false;
  const auto val_shape = val_type.getShape();
  if (val_shape.empty()) return false;
  const auto last_element = *val_shape.rbegin();
  return last_element == 1;
}

// Returns true if attr is a DenseIntElementsAttr of int32 or int64 values
// that form an incrementing sequence from 0 to N-1.
//
// If such a value is used in an Equal operator, it can be replaced with OneHot.
bool IsOneHotIndexAttribute(Attribute attr) {
  const auto dense_attr = attr.dyn_cast_or_null<DenseIntElementsAttr>();
  if (!dense_attr) {
    return false;
  }
  auto index_type = dense_attr.getType();
  const auto index_elem_bits = index_type.getElementTypeBitWidth();
  if (index_elem_bits != 32 && index_elem_bits != 64) {
    return false;
  }
  if (index_type.getRank() != 1) {
    return false;
  }
  const auto elems = dense_attr.value_begin<APInt>();
  for (int i = 0; i < dense_attr.getNumElements(); ++i) {
    if (i != elems[i]) {
      return false;
    }
  }
  return true;
}
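// For example, dense<[0, 1, 2, 3]> : tensor<4xi32> qualifies, while
// dense<[0, 2]> : tensor<2xi32> does not.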

// Creates FullyConnected op from params and returns the output.
mlir::Value GetFcOutput(OpBuilder *builder,
                        ::mlir::Operation::result_range result, Value input,
                        Value filter, Value bias,
                        StringAttr fused_activation_function,
                        StringAttr weights_format, BoolAttr keep_num_dims,
                        BoolAttr asymmetric_quantize_inputs) {
  auto fc_op = builder->create<FullyConnectedOp>(
      result[0].getLoc(), result.getTypes(), input, filter, bias,
      fused_activation_function, weights_format, keep_num_dims,
      asymmetric_quantize_inputs);
  return fc_op->getResult(0);
}

// Returns true if 'value' represents a const ElementsAttr with all values
// equal to 0.0.
bool AllValuesAreZero(mlir::Value value) {
  if (!value) return false;
  DenseElementsAttr vals;
  if (!matchPattern(value, m_Constant(&vals))) return false;
  for (auto elem : vals.getValues<float>())
    if (elem != 0.0f) return false;
  return true;
}

// Converts an Attribute with a single value of float or integral type to an
// Attribute holding a single value of float type. If attr has no elements, the
// result is 0.0f.
Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) {
  const auto dense_fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
  if (dense_fp_attr) {
    // Already float => return
    return dense_fp_attr;
  }

  OpBuilder builder(attr.getContext());

  const auto dense_int_attr = attr.dyn_cast<DenseIntElementsAttr>();
  const auto int_values = dense_int_attr.getValues<APInt>();
  float float_val = 0.0f;
  if (!int_values.empty()) {
    const APInt apint_val = *int_values.begin();
    if (dense_int_attr.getType().getElementType().isSignedInteger()) {
      // Get the sign-extended value (=>int64) if the type is signed.
      float_val = apint_val.getSExtValue();
    } else {
      // Get the zero-extended value (=>uint64) if unsigned or signless.
      float_val = apint_val.getZExtValue();
    }
  }
  return DenseFPElementsAttr::get(
      RankedTensorType::get({}, builder.getF32Type()),
      {llvm::APFloat(float_val)});
}
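// For example, dense<5> : tensor<i32> converts to dense<5.0> : tensor<f32>,
// and an attribute with no elements converts to dense<0.0> : tensor<f32>.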

#include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"

// Fuses Add with the preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
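// For example, a sketch of the rewrite when the FC has no bias and the added
// constant is 1-D (types and other attributes elided):
//   %0 = "tfl.fully_connected"(%in, %filter, %none)
//       {fused_activation_function = "NONE"}
//   %1 = tfl.add(%0, %cst) {fused_activation_function = "RELU"}
// becomes
//   %1 = "tfl.fully_connected"(%in, %filter, %cst)
//       {fused_activation_function = "RELU"}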
struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
  using OpRewritePattern<TFL::AddOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::AddOp add_op,
                                PatternRewriter &rewriter) const override {
    // Match Add.
    DenseElementsAttr added_value;
    Value constant_val = add_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();

    // Match Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
    if (!fc_op) return failure();

    // Check if the constant RHS is either 0D (scalar) or a 1D tensor with
    // `{num_channels}` shape.
    auto constant_val_type = constant_val.getType().cast<TensorType>();

    // In the TFLite FullyConnected definition, bias must be a 1D tensor where
    // the number of elements is equal to the number of channels.
    // If it's not 1D or 0D (which can be broadcast to 1D), reject the
    // matching.
    bool is_scalar_rhs = false;
    if (constant_val_type.getRank() == 0) {
      is_scalar_rhs = true;
    } else if (constant_val_type.getRank() != 1) {
      return failure();
    }

    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr bias_value;
    const bool is_none_bias = bias.getType().isa<NoneType>();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
      return failure();

    // Rewrite
    if (is_none_bias) {
      if (is_scalar_rhs) {
        // If the `constant_val` is scalar, we must use the shape of the filter
        // to properly broadcast the scalar to `{num_channels}` shape.

        // Get the number of channels if possible.
        auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
        // Filter must be a `2D` tensor with `{num_channels, num_features}`
        // shape. The following check rejects unknown rank (-1).
        if (filter_type == nullptr || filter_type.getRank() != 2) {
          return failure();
        }
        int num_channels = filter_type.getShape()[0];

        // Create a zero tensor with shape {num_channels}; its type needs to
        // be the same as constant_val's.
        // This is a way to gracefully handle the scalar tensor. The Add will
        // always be constant-folded away, regardless of whether `constant_val`
        // is a scalar or not.
        RankedTensorType type = RankedTensorType::get(
            {num_channels}, constant_val_type.getElementType());
        auto attr = rewriter.getZeroAttr(type);
        bias = rewriter.create<arith::ConstantOp>(add_op.getLoc(), type, attr);
        auto none_af = rewriter.getStringAttr("NONE");
        bias =
            rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
                .output();
      } else {
        // If there is no pre-existing bias and the `constant_val` is 1D,
        // simply use `constant_val` as bias.
        bias = constant_val;
      }
    } else {
      auto none_af = rewriter.getStringAttr("NONE");
      bias =
          rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
              .output();
    }

    auto fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), add_op.getLoc()}),
        add_op.getType(),
        /*input=*/fc_op.input(),
        /*filter=*/filter,
        /*bias=*/bias,
        /*fused_activation_function=*/
        rewriter.getStringAttr(add_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
        /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
    rewriter.replaceOp(add_op, fc.output());

    return success();
  }
};

// Replace ..
// FC(Add(lhs, rhs), filter, bias)
// .. with ..
// FC(lhs, filter, FC(rhs, filter, bias))
// .. if rhs, filter, and bias are all constants.
// The second FC will be constant folded to a single vector.
// TODO(b/136285429): Move to tablegen when variadic is supported
struct FuseAddAndFullyConnected
    : public OpRewritePattern<TFL::FullyConnectedOp> {
  using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
                                PatternRewriter &rewriter) const override {
    // This only works with default format.
    if (fc_op.weights_format() != "DEFAULT") return failure();

    // Match Add.
    auto add_op = dyn_cast_or_null<TFL::AddOp>(fc_op.input().getDefiningOp());
    if (!add_op) return failure();
    if (add_op.fused_activation_function() != "NONE") return failure();

    // Don't match adds where the added constant is not 1D.
    {
      auto addend_shape = add_op.rhs().getType().cast<ShapedType>();
      if (!addend_shape.hasStaticShape()) return failure();
      if (addend_shape.getShape().size() != 1) return failure();
    }

    // Calculate new bias.  Generate a new FC; it will be constant folded.
    auto old_bias = fc_op.bias();
    if (!old_bias || old_bias.getType().isa<NoneType>()) {
      // TODO(b/180752069): Figure out new bias' type when old bias is empty.
      return failure();
    }

    // The FC relies on constant folding, which is implemented only for F32,
    // so check that the types are F32.
    {
      if (!IsF32Value(add_op.rhs()) || !IsF32Value(fc_op.filter()) ||
          !IsF32Value(old_bias))
        return failure();
    }

    auto new_bias = rewriter.create<TFL::FullyConnectedOp>(
        fc_op.getLoc(), old_bias.getType(),
        /*input=*/add_op.rhs(),
        /*filter=*/fc_op.filter(),
        /*bias=*/old_bias,
        /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
        /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
        /*keep_num_dims=*/rewriter.getBoolAttr(true),
        /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());

    // Create the updated FC.
    auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get(add_op.getContext(), {add_op.getLoc(), fc_op.getLoc()}),
        fc_op.output().getTypes(),
        /*input=*/add_op.lhs(),
        /*filter=*/fc_op.filter(),
        /*bias=*/*new_bias.output().begin(),
        /*fused_activation_function=*/
        rewriter.getStringAttr(fc_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
        /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
    rewriter.replaceOp(fc_op.getOperation(), new_fc.output());

    return success();
  }
};

// Replace ..
// FC(Mul(lhs, rhs), filter, bias)
// .. with ..
// FC(lhs, Mul(filter, rhs), bias)
// .. if rhs, filter, and bias are all constants.
// The generated Mul will be constant folded to a single matrix.
struct FuseMulAndFullyConnected
    : public OpRewritePattern<TFL::FullyConnectedOp> {
  using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
                                PatternRewriter &rewriter) const override {
    // This only works with default format.
    if (fc_op.weights_format() != "DEFAULT") return failure();

    // Match Mul.
    auto mul_op = dyn_cast_or_null<TFL::MulOp>(fc_op.input().getDefiningOp());
    if (!mul_op) return failure();
    if (mul_op.fused_activation_function() != "NONE") return failure();

    // Don't match muls where the multiplier constant is not 1D.
    {
      auto multiplier_shape = mul_op.rhs().getType().cast<ShapedType>();
      if (!multiplier_shape.hasStaticShape()) return failure();
      if (multiplier_shape.getShape().size() != 1) return failure();
    }

    // We rely on constant folding, implemented only for F32. Check types.
    if (!IsF32Value(mul_op.rhs()) || !IsF32Value(fc_op.filter())) {
      return failure();
    }

    auto location =
        FusedLoc::get(mul_op.getContext(), {mul_op.getLoc(), fc_op.getLoc()});

    auto new_filter = rewriter.create<TFL::MulOp>(
        location,
        /*lhs=*/fc_op.filter(),
        /*rhs=*/mul_op.rhs(),
        /*fused_activation_function=*/rewriter.getStringAttr("NONE"));
    // Create the updated FC.
    auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
        location, fc_op.output().getTypes(),
        /*input=*/mul_op.lhs(),
        /*filter=*/new_filter,
        /*bias=*/fc_op.bias(),
        /*fused_activation_function=*/
        rewriter.getStringAttr(fc_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
        /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
    rewriter.replaceOp(fc_op.getOperation(), new_fc.output());

    return success();
  }
};

// TODO(b/136285429): Move to tablegen when variadic is supported.
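// Fuses a trailing ReluX op (RELU, RELU6, or RELU_N1_TO_1, selected by `Act`)
// into the fused_activation_function of the preceding FullyConnected, when
// that FullyConnected has no fused activation yet.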
template <typename ReluXOp, char const *Act>
struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
  using OpRewritePattern<ReluXOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReluXOp relu_op,
                                PatternRewriter &rewriter) const override {
    Operation *input = relu_op.getOperand().getDefiningOp();
    if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
    auto fully_connected_op = cast<FullyConnectedOp>(input);
    if (fully_connected_op.fused_activation_function() != "NONE")
      return failure();

    auto new_activation_func = rewriter.getStringAttr(Act);
    auto new_weights_format =
        rewriter.getStringAttr(fully_connected_op.weights_format());
    auto new_keep_num_dims =
        rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
    auto fc = rewriter.create<FullyConnectedOp>(
        FusedLoc::get(relu_op.getContext(),
                      {fully_connected_op.getLoc(), relu_op.getLoc()}),
        relu_op.getType(), /*input=*/fully_connected_op.input(),
        /*filter=*/fully_connected_op.filter(),
        /*bias=*/fully_connected_op.bias(),
        /*fused_activation_function=*/new_activation_func,
        /*weights_format=*/new_weights_format,
        /*keep_num_dims=*/new_keep_num_dims,
        /*asymmetric_quantize_inputs=*/
        fully_connected_op.asymmetric_quantize_inputsAttr());
    rewriter.replaceOp(relu_op, fc.output());

    return success();
  }
};

// Fuses Mul with the preceding FullyConnected.
// TODO(b/136285429): Move to tablegen when variadic is supported
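// In effect: FC(x; filter, bias) * cst becomes
// FC(x; filter * cst', bias * cst), where cst' is cst reshaped to broadcast
// along the output-channel dimension of the filter.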
struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
    // If we are broadcasting on the lhs then don't fold the multiply as it
    // would increase the amount of compute done by the fully connected op.
    if (mul_op.lhs().getType() != mul_op.getType()) return failure();

    // Mul.
    DenseElementsAttr cst;
    Value constant_val = mul_op.rhs();
    if (!matchPattern(constant_val, m_Constant(&cst))) return failure();

    // Fully Connected.
    auto fc_op =
        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    ElementsAttr cst_tmp;
    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return failure();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Only fuse multiplier if all dimensions other than the depth dimension
    // are equal to 1 since otherwise
    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
    // even if `filter` and `cst` are broadcastable.
    auto shape = cst.getType().getShape();
    if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();

    int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
    // Expand and transpose the multiplier since weights are using the
    // OHWI data format in TFLite.
    int64_t normalized_shape[2] = {element_size, 1};
    auto new_cst = cst.reshape(RankedTensorType::get(
        normalized_shape, cst.getType().getElementType()));
    Type new_type = new_cst.getType();
    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
      return failure();
    }

    auto new_op =
        rewriter.create<arith::ConstantOp>(mul_op.getLoc(), new_type, new_cst);
    Value new_const_val = new_op.getResult();

    // Rewrite. Since the folder of TFL::MulOp cannot broadcast the operands,
    // TF::MulOp is used to fold the constant.
    // TODO(b/139192933): switch to the TFL constant folding
    auto new_filter =
        rewriter.create<TF::MulOp>(mul_op.getLoc(), filter, new_const_val).z();
    // If bias isn't None, it needs to be multiplied as well.
    if (!bias.getType().isa<NoneType>()) {
      bias =
          rewriter.create<TF::MulOp>(mul_op.getLoc(), bias, constant_val).z();
    }

    auto fc = rewriter.create<TFL::FullyConnectedOp>(
        FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), mul_op.getLoc()}),
        mul_op.getType(),
        /*input=*/fc_op.input(),
        /*filter=*/new_filter,
        /*bias=*/bias,
        /*fused_activation_function=*/
        rewriter.getStringAttr(mul_op.fused_activation_function()),
        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
        /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
    rewriter.replaceOp(mul_op, fc.output());

    return success();
  }
};

// Fuses Mul with the preceding affine ops. This is a C++ implementation of
// the following TableGen pattern, which cannot derive the result type of
// the TFL_DequantizeOp.
// def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
//                          (TFL_DequantizeOp (TFL_QuantizeOp
//                              (Arith_ConstantOp F32ElementsAttr:$filter),
//                              $qtype)),
//                          (Arith_ConstantOp F32ElementsAttr:$bias),
//                          $h_factor, $w_factor, TFL_AF_None,
//                          $padding, $stride_h, $stride_w),
//                      (Arith_ConstantOp F32ElementsAttr:$value), $act_fn),
//           (TFL_Conv2DOp $input,
//                      (TFL_DequantizeOp (TFL_QuantizeOp
//                          (TFL_MulOp (Arith_ConstantOp $filter),
//                                     (Arith_ConstantOp (ExpandTo4DForConv
//                                     $value)),
//                                      TFL_AF_None),
//                          (RescaleQtype $qtype, $value))),
//                      (TFL_MulOp (Arith_ConstantOp $bias), (Arith_ConstantOp
//                      $value),
//                          TFL_AF_None),
//                      $h_factor, $w_factor, $act_fn,
//                      $padding, $stride_h, $stride_w),
//         [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
//          (HasOneUse $conv_output),
//          (IsPerAxisQuantization $qtype), // per-axis quantization
//         ]>;
template <typename AffineOpType>
struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
  using OpRewritePattern<TFL::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
    // Mul. A 1-D rhs is required for batch normalization.
    DenseElementsAttr gamma_cst;
    Value gamma = mul_op.rhs();
    if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
    if (gamma_cst.getType().getRank() != 1) return failure();

    // Affine op
    Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
    auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
    if (!fc_op) return failure();
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();

    // QDQs
    auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
    if (!dq_op) return failure();
    auto q_op =
        dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
    if (!q_op) return failure();
    filter = q_op.input();

    // weight constant
    ElementsAttr cst_tmp;
    if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&cst_tmp)))
      return failure();
    if (fc_op.fused_activation_function() != "NONE") return failure();

    // Broadcast the constant operand of Mul if it isn't compatible with the
    // filter input. We only support broadcasting the operand along the depth
    // dimension, when the operand's depth is 1.
    rewriter.setInsertionPoint(q_op);
    Location loc = fc_op.getLoc();
    Value broadcasted_gamma;
    if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
      auto mul_rhs = ExpandTo4DForConv(gamma_cst);
      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
    } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
      auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
      broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
    } else {
      return failure();
    }

    // Make sure that the fused bias will be a 1D tensor.
    auto gamma_shape = gamma.getType().cast<ShapedType>();
    if (!gamma_shape.hasRank() || gamma_shape.getRank() != 1) {
      return failure();
    }

    // Rewrite filter constant. Since the folder of TFL::MulOp cannot
    // broadcast the operands, TF::MulOp is used to fold the constant.
    auto new_filter =
        rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
    // Update the scale in the quantize op.
    auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
    if (!new_qtype) return failure();
    rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
                                                 new_filter, new_qtype);

    // If bias isn't None, it needs to be multiplied as well.
    if (!bias.getType().isa<NoneType>()) {
      rewriter.setInsertionPoint(fc_op);
      auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
      fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
    }

    // Remove the trailing mul op.
    mul_op.replaceAllUsesWith(fc_op.getResult());
    return success();
  }
};

using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
using FuseDepthwiseConv2DAndMulWithQDQs =
    FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;

// Fuses a binary op with the following affine operation.
template <typename AffineOpType>
struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
  using OpRewritePattern<AffineOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineOpType fc_op,
                                PatternRewriter &rewriter) const override {
    // Binary op.
    Operation *binary_op = fc_op.input().getDefiningOp();
    if (!binary_op || binary_op->getNumOperands() != 2) return failure();
    // We only handle the case where the RHS is a scalar.
    // TODO(fengliuai): Currently the canonicalizer pass cannot guarantee that
    // the constant operands are on the RHS; we need to consider the LHS
    // constant operand if necessary.
    DenseFPElementsAttr cst;
    if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
      return failure();
    if (cst.getNumElements() != 1) return failure();
    APFloat cst_value = *cst.value_begin<APFloat>();

    // Affine op.
    Value filter = fc_op.filter();
    Value bias = fc_op.bias();
    DenseFPElementsAttr filter_cst, bias_cst;
    if (!matchPattern(filter, m_Constant(&filter_cst))) {
      // The filter may be quantized; if so, trace through the QDQ ops to the
      // real constant.
      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
      if (!dq) return failure();
      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
      if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
        return failure();
      }
      filter = q.input();
    }
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&bias_cst)))
      return failure();
    auto binary_op_activation_func =
        binary_op->template getAttrOfType<StringAttr>(
            "fused_activation_function");
    if (!binary_op_activation_func ||
        binary_op_activation_func.getValue() != "NONE")
      return failure();
    ShapedType filter_type = filter_cst.getType();

    if (llvm::isa<AddOp, SubOp>(binary_op)) {
      auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
      if (padding && padding.getValue() != "VALID") return failure();

      // The fusion of add/sub is actually applying the following
      // transformation:
      // w * (x + c) + b => w * x + (w * c + b)
      // so we have to update the bias.
      if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();

      auto bias_and_slice =
          GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
      int64_t bias_size = bias_and_slice.first;
      int64_t slice_size = bias_and_slice.second;
      ShapedType new_bias_type =
          RankedTensorType::get({bias_size}, filter_type.getElementType());

      // The new bias should be a 1-D tensor with length equal to the bias
      // dimension of the weight.
      SmallVector<APFloat, 4> new_bias_values;
      if (bias.getType().isa<NoneType>()) {  // none bias, a list of zeros
        new_bias_values.resize(bias_size,
                               APFloat::getZero(cst_value.getSemantics()));
      } else if (bias_cst.getNumElements() == 1) {  // scalar bias, broadcast it
        new_bias_values.resize(bias_size, *bias_cst.value_begin<APFloat>());
      } else if (bias_cst.getNumElements() == bias_size) {  // 1-d bias, copy it
        new_bias_values.insert(new_bias_values.begin(),
                               bias_cst.value_begin<APFloat>(),
                               bias_cst.value_end<APFloat>());
      } else {
        return failure();
      }

      int64_t flatten_index = 0;
      for (auto fp_it = filter_cst.value_begin<APFloat>(),
                fp_end = filter_cst.value_end<APFloat>();
           fp_it != fp_end; ++fp_it) {
        int bias_index = (flatten_index++ / slice_size) % bias_size;

        new_bias_values[bias_index] =
            new_bias_values[bias_index] + *fp_it * cst_value;
      }
      auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
      auto new_bias_op =
          rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
      fc_op.setOperand(0, binary_op->getOperand(0));
      fc_op.setOperand(2, new_bias_op);
    } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
      // The fusion of mul/div is actually applying the following
      // transformation, where ' stands for mul or div:
      // w * (x ' c) + b => (w ' c) * x + b
      // so we have to update the weight.
      bool is_mul = llvm::isa<MulOp>(binary_op);
      auto new_filter =
          filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
            return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
          });
      // We recreate the constant op in case it is shared by other ops. This
      // might increase the model size.
      auto new_filter_op = rewriter.create<ConstOp>(
          fc_op.getLoc(), filter.getType(), new_filter);
      fc_op.setOperand(0, binary_op->getOperand(0));
      if (fc_op.filter() != filter) {
        // This filter goes through quantize and dequantize ops, so we just
        // need to update the weight of the quantize op.
        filter.replaceAllUsesWith(new_filter_op);
      } else {
        // This filter doesn't go through quantize and dequantize ops, so we
        // update the weight of the affine op directly.
        fc_op.setOperand(1, new_filter_op);
      }
    } else {
      return failure();
    }
    return success();
  }

 private:
  // Returns the length of the channel dimension and the slice size per
  // position along the channel dimension. tfl.conv2d and tfl.fully_connected
  // have a leading channel dimension, while tfl.depthwise_conv2d has a
  // trailing channel dimension. This utility derives that information from
  // the op's properties.
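  // For example, a tfl.conv2d filter of shape [8, 3, 3, 4] has channel
  // dimension index 0, so this returns {8, 3 * 3 * 4} = {8, 36}.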
  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
      ArrayRef<int64_t> filter_shape, AffineOpType op) {
    // Channel dimension index is specified as an op property.
    auto channel_index_iter = filter_shape.begin();
    std::advance(channel_index_iter, op.GetChannelDimIndex());
    // The slice size is the size of the data in the higher (trailing)
    // dimensions.
    int64_t slice_size =
        std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
                        std::multiplies<int64_t>());
    return {*channel_index_iter, slice_size};
  }
};
1142 
1143 // If the operand to a broadcastable op is a splat constant, try to replace it
1144 // with a 0-d constant, e.g. before this optimization,
1145 //   %cst = arith.constant dense<1.0> : tensor<16x16x4xf32>
1146 //   %0 = "tfl.conv_2d"...
1147 //   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
1148 // After this optimization:
1149 //   %cst = arith.constant dense<1.0> : tensor<f32>
1150 //   %0 = "tfl.conv_2d"...
1151 //   %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
1152 // This pattern can enable more fusing opportunities when the binary op is
1153 // following conv ops.
1154 template <typename BinaryOpType>
1155 struct ScalarizeSplatConstantForBroadcastableOps
1156     : public OpRewritePattern<BinaryOpType> {
1157   using OpRewritePattern<BinaryOpType>::OpRewritePattern;
1158 
matchAndRewritemlir::TFL::__anonacf349380111::ScalarizeSplatConstantForBroadcastableOps1159   LogicalResult matchAndRewrite(BinaryOpType binary_op,
1160                                 PatternRewriter &rewriter) const override {
1161     DenseElementsAttr splat_elements_attr;
1162     if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
1163       return failure();
1164     }
1165 
1166     constexpr int kSplatOperandIndex = 1;
1167     auto result_type =
1168         binary_op.getResult().getType().template cast<ShapedType>();
1169     mlir::Value non_splat_operand =
1170         binary_op.getOperand(1 - kSplatOperandIndex);
1171     auto non_splat_operand_type =
1172         non_splat_operand.getType().cast<ShapedType>();
1173     // If the other operand's shape does not equal to the result shape, then we
1174     // cannot scalarize the splat constant because the result shape relies on
1175     // the splat constant op's shape for broadcasting.
1176     if (!non_splat_operand_type.hasStaticShape() ||
1177         non_splat_operand_type.getShape() != result_type.getShape() ||
1178         non_splat_operand_type.getRank() > 4) {
1179       return failure();
1180     }
1181 
1182     // If non-splat operand is not fusable affine ops, then no need to apply
1183     // this transformation.
1184     if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
1185       return failure();
1186     }
1187 
1188     // Creates a new scalar constant op using the splat value.
1189     mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
1190     auto scalar_elements_attr = DenseElementsAttr::get(
1191         RankedTensorType::get({},
1192                               splat_elements_attr.getType().getElementType()),
1193         splat_elements_attr.getSplatValue<mlir::Attribute>());
1194 
1195     auto scalar_constant_op = rewriter.create<arith::ConstantOp>(
1196         splat_operand.getLoc(), scalar_elements_attr.getType(),
1197         scalar_elements_attr);
1198 
1199     binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
1200     return success();
1201   }
1202 
1203  private:
1204   // Returns true if this value is a splat constant op which can be scalarized.
1205   // Also returns the elements attr if this value is indeed a splat constant.
IsScalarizableSplatConstantmlir::TFL::__anonacf349380111::ScalarizeSplatConstantForBroadcastableOps1206   bool IsScalarizableSplatConstant(mlir::Value value,
1207                                    DenseElementsAttr *elements_attr) const {
1208     if (!matchPattern(value, m_Constant(elements_attr))) {
1209       return false;
1210     }
1211     auto element_type = value.getType().cast<ShapedType>().getElementType();
1212     // Ignore per-axis quantized constants because after converting to scalar,
1213     // we will lose per-axis qantization parameter.
1214     if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
1215       return false;
1216     }
1217     if (IsScalar(value)) {
1218       return false;
1219     }
1220     return elements_attr->isSplat();
1221   }

  // Returns true if this value has a scalar shaped type.
  bool IsScalar(mlir::Value value) const {
    auto type = value.getType().dyn_cast<ShapedType>();
    if (!type) {
      return false;
    }
    if (!type.hasStaticShape()) {
      return false;
    }
    return type.getNumElements() == 1;
  }

  // Returns true if we can fuse the affine op with the consuming binary op.
  bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
    if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
                         TFL::FullyConnectedOp>(affine_op)) {
      return false;
    }
    DenseElementsAttr value;
    // Check that the bias is constant if it is not none.
    Value bias = affine_op->getOperand(2);
    if (!bias.getType().isa<NoneType>() &&
        !matchPattern(bias, m_Constant(&value))) {
      return false;
    }
    // If the binary op is mul/div, also check that the filter is constant,
    // since a multiplicative factor can only be folded into constant weights.
    if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
        !matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
      return false;
    }

    // We can only fuse F32/BF16.
    auto is_fusable_type = [](Type t) {
      Type element_type = t;
      if (auto shaped_type = t.dyn_cast<ShapedType>()) {
        element_type = shaped_type.getElementType();
      }
      return element_type.isBF16() || element_type.isF32();
    };
    for (Type t : binary_op->getOperandTypes()) {
      if (!is_fusable_type(t)) {
        return false;
      }
    }

    return true;
  }
};

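// As an illustrative sketch (the shapes and values are hypothetical), the
// pattern instantiations below turn
//
//   %cst = arith.constant dense<1.5> : tensor<1x64xf32>
//   %out = tfl.add(%conv_out, %cst) {fused_activation_function = "NONE"}
//
// into
//
//   %cst_scalar = arith.constant dense<1.5> : tensor<f32>
//   %out = tfl.add(%conv_out, %cst_scalar) {fused_activation_function = "NONE"}
//
// so that later fusion patterns can fold the scalar constant into the bias of
// the preceding affine op (conv/depthwise-conv/fully-connected).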
using ScalarizeSplatConstantForSub =
    ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
using ScalarizeSplatConstantForAdd =
    ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
using ScalarizeSplatConstantForMul =
    ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
using ScalarizeSplatConstantForDiv =
    ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;

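// Converts a tfl.transpose that permutes only size-1 dimensions into a
// tfl.reshape, since such a permutation never reorders the underlying data.
// A sketch with hypothetical shapes:
//
//   %perm = arith.constant dense<[1, 0, 2]> : tensor<3xi32>
//   %t = "tfl.transpose"(%input, %perm)  // %input: tensor<1x128x64xf32>
//
// can be rewritten to
//
//   %shape = "tf.Const"() {value = dense<[128, 1, 64]> : tensor<3xi32>}
//   %t = "tfl.reshape"(%input, %shape)   // result: tensor<128x1x64xf32>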
struct ConvertTrivialTransposeOpToReshapeOp
    : public OpRewritePattern<TFL::TransposeOp> {
  using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
                                PatternRewriter &rewriter) const override {
    auto input_type = transpose_op.input().getType().cast<ShapedType>();
    auto output_type = transpose_op.output().getType().cast<ShapedType>();
    // It is possible to know whether the transformation is safe only if the
    // input and output shapes are fully known and the permutation is constant.
    if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
      return failure();
    Value perm = transpose_op.perm();
    DenseElementsAttr perm_values_attr;
    if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();

    auto input_shape = input_type.getShape();
    SmallVector<int64_t, 8> perm_values;
    for (const auto &dim : perm_values_attr.getValues<APInt>())
      perm_values.push_back(dim.getSExtValue());

    // This should never happen unless the input graph is malformed. Bail out
    // rather than indexing out of bounds below.
    if (input_shape.size() != perm_values.size()) {
      transpose_op.emitError(
          "TransposeOp has inconsistent input and perm values.");
      return failure();
    }

    SmallVector<int, 8> old_major_index_ordering;
    SmallVector<int, 8> new_major_index_ordering;
    for (int i = 0, end = input_shape.size(); i < end; i++) {
      if (input_shape[i] != 1) {
        old_major_index_ordering.push_back(i);
      }

      if (input_shape[perm_values[i]] != 1) {
        new_major_index_ordering.push_back(perm_values[i]);
      }
    }
    if (old_major_index_ordering != new_major_index_ordering) {
      return failure();
    }

    // Rewrite.
    Location loc = transpose_op.getLoc();

    SmallVector<int32_t, 8> output_shape_values;
    for (auto dim : output_type.getShape()) {
      output_shape_values.push_back(dim);
    }
    auto type = mlir::RankedTensorType::get(output_shape_values.size(),
                                            rewriter.getIntegerType(32));
    auto new_shape_attr =
        mlir::DenseIntElementsAttr::get(type, output_shape_values);
    auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);

    rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
        transpose_op, transpose_op.output().getType(), transpose_op.input(),
        new_shape);

    return success();
  }
};

// Remove Reshape before FullyConnected when `keep_num_dims=false` and the
// Reshape does not alter the last dimension, since FullyConnected will
// collapse all other dimensions into a single dimension anyway. For example,
//
//   %shape = arith.constant dense<[1, 128, 64]> : tensor<3xi32>
//   %reshape = tfl.reshape(%input, %shape) // %input: tensor<128x64xf32>
//   %fc = tfl.fully_connected(%reshape, %filter, %bias)
//           {keep_num_dims = false, weights_format = "DEFAULT"}
//
// can be canonicalized to
//
//   %fc = tfl.fully_connected(%input, %filter, %bias)
//           {keep_num_dims = false, weights_format = "DEFAULT"}
struct RemoveReshapeBeforeFullyConnected
    : public OpRewritePattern<TFL::FullyConnectedOp> {
  using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fully_connected_op,
                                PatternRewriter &) const override {
    auto input = fully_connected_op.input();
    auto input_ty = input.getType().dyn_cast<ShapedType>();
    auto output_ty = fully_connected_op.output()[0]
                         .getType()
                         .template dyn_cast<ShapedType>();
    if (!input_ty.hasStaticShape() ||
        fully_connected_op.weights_format() != "DEFAULT" ||
        fully_connected_op.keep_num_dims() || !output_ty.hasStaticShape() ||
        output_ty.getRank() != 2) {
      return failure();
    }

    auto reshape_op = input.getDefiningOp<TFL::ReshapeOp>();
    if (!reshape_op) return failure();

    // Check that the last dimension does not change after the reshape.
    auto reshape_input = reshape_op.input();
    auto reshape_input_ty = reshape_input.getType().dyn_cast<ShapedType>();
    if (!reshape_input_ty.hasStaticShape() || input_ty.getRank() == 0 ||
        reshape_input_ty.getRank() == 0 ||
        input_ty.getDimSize(input_ty.getRank() - 1) !=
            reshape_input_ty.getDimSize(reshape_input_ty.getRank() - 1)) {
      return failure();
    }

    // Connect the input directly to the input of the reshape op.
    fully_connected_op.setOperand(0, reshape_input);
    return success();
  }
};

// Remove Reshape after FullyConnected when `keep_num_dims=false`, the Reshape
// does not alter the last dimension, and it restores the batch dimensions
// collapsed by the FullyConnected op due to `keep_num_dims=false`. For
// example,
//
//   // %input: tensor<4x16x32xf32>
//   %fc = tfl.fully_connected(%input, %filter, %bias)
//           {keep_num_dims = false, weights_format = "DEFAULT"}
//   %shape = arith.constant dense<[4, 16, 32]> : tensor<3xi32>
//   %rs = tfl.reshape(%fc, %shape)
//
// can be canonicalized to
//
//   %fc = tfl.fully_connected(%input, %filter, %bias)
//           {keep_num_dims = true, weights_format = "DEFAULT"}
struct RemoveReshapeAfterFullyConnected
    : public OpRewritePattern<TFL::ReshapeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::ReshapeOp reshape_op,
                                PatternRewriter &rewriter) const override {
    auto fully_connected_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
        reshape_op.input().getDefiningOp());
    if (!fully_connected_op || fully_connected_op.getNumResults() != 1 ||
        fully_connected_op.weights_format() != "DEFAULT" ||
        fully_connected_op.keep_num_dims())
      return failure();
    if (!reshape_op.input().hasOneUse()) return failure();

    auto input_shape = fully_connected_op.input().getType().cast<ShapedType>();
    auto output_shape = fully_connected_op.getType(0).cast<ShapedType>();
    auto reshape_shape = reshape_op.getType().cast<ShapedType>();
    if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() ||
        !reshape_shape.hasStaticShape())
      return failure();

    // Check that the reshape doesn't modify the last dimension and that it
    // restores the input (batch) dimensions, with the exception of the
    // feature (last) dimension.
    if (output_shape.getShape().empty() || reshape_shape.getShape().empty() ||
        output_shape.getShape().back() != reshape_shape.getShape().back() ||
        input_shape.getShape().drop_back() !=
            reshape_shape.getShape().drop_back())
      return failure();

    llvm::SmallVector<Type, 1> output_type{reshape_op.getType()};
    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
        reshape_op, output_type, /*input=*/fully_connected_op.input(),
        /*filter=*/fully_connected_op.filter(),
        /*bias=*/fully_connected_op.bias(),
        /*fused_activation_function=*/
        fully_connected_op.fused_activation_function(),
        /*weights_format=*/fully_connected_op.weights_format(),
        /*keep_num_dims=*/true,
        /*asymmetric_quantize_inputs=*/
        fully_connected_op.asymmetric_quantize_inputsAttr());
    return success();
  }
};

// Fuses Unpack with the following Concatenation into a Reshape if the output
// type has a static shape and the activation function is none. For example:
//
//   // %input: tensor<1x3x2xf32>
//   %unpack:3 = "tfl.unpack"(%input) {axis = 1 : i32, num = 3 : i32}
//   %res = "tfl.concatenation"(%unpack#0, %unpack#1, %unpack#2)
//        {axis = -1 : i32, fused_activation_function = "NONE"}
//
// can be optimized to
//
//   %cst = arith.constant dense<[1, 6]> : tensor<2xi32>
//   %res = "tfl.reshape"(%input, %cst)
struct FuseUnpackAndConcatToReshape
    : public OpRewritePattern<TFL::ConcatenationOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op,
                                PatternRewriter &rewriter) const override {
    if (concat_op.fused_activation_function() != "NONE") {
      return failure();
    }

    // Check that all operands come from the same unpack op.
    auto first_operand = concat_op.values().front();
    auto unpack_op =
        dyn_cast_or_null<TFL::UnpackOp>(first_operand.getDefiningOp());
    if (!unpack_op || unpack_op.getNumResults() != concat_op.getNumOperands()) {
      return failure();
    }
    for (auto &index_and_value : llvm::enumerate(concat_op.values())) {
      if (index_and_value.value() !=
          unpack_op.getResult(index_and_value.index())) {
        return failure();
      }
    }

    auto output_type = concat_op.getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) {
      return failure();
    }

    auto new_shape_array = output_type.getShape();
    // This is to work around an unnecessary i64 -> i32 cast.
    SmallVector<int32_t, 4> new_shape_array_i32;
    for (auto size : new_shape_array) {
      new_shape_array_i32.push_back(static_cast<int32_t>(size));
    }
    auto new_shape = rewriter.create<TFL::ConstOp>(
        concat_op.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get(new_shape_array_i32.size(),
                                  rewriter.getIntegerType(32)),
            new_shape_array_i32));

    rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(concat_op, output_type,
                                                unpack_op.input(), new_shape);
    return success();
  }
};

// Reduce the K of a TopKV2Op for the following case:
//
//   %values, %indices = tfl.topkv2(%inputs, K)
//   %1 = tfl.slice(%values, 0, k)
//   %2 = tfl.slice(%indices, 0, k)
//   .... (%values and %indices are only used by %1 and %2)
//
// %1 or %2 can be absent. If %values and %indices are only used here,
// this pattern can be replaced with (conceptually)
//
//   %values, %indices = tfl.topkv2(%inputs, k)
//   replace all uses of %1 with %values
//   replace all uses of %2 with %indices
//
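// As a concrete sketch (the shapes and constants are hypothetical):
//
//   %values, %indices = tfl.topkv2(%input, %K)  // %K = 10, %values: tensor<2x10xf32>
//   %v = tfl.slice(%values, %zeros, %size)      // %size = [2, 4]
//   %i = tfl.slice(%indices, %zeros, %size)
//
// is rewritten into a single tfl.topkv2 with k = 4, whose tensor<2x4xf32>
// values and tensor<2x4xi32> indices directly replace all uses of %v and %i.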
struct OptimizeTopK : public OpRewritePattern<TFL::TopKV2Op> {
  using OpRewritePattern::OpRewritePattern;

  // Computes the last dimension k of the slice size of value's user.
  // Returns 0 if value has no use, and llvm::None if it cannot be matched.
  llvm::Optional<int32_t> ComputeSliceK(Value value) const {
    if (value.use_empty()) return 0;
    auto slice_op =
        llvm::dyn_cast_or_null<TFL::SliceOp>(value.getUses().begin().getUser());
    // We only match the case where value is used by a SliceOp.
    if (!slice_op) return llvm::None;
    DenseElementsAttr begin;
    DenseElementsAttr size;
    if (!matchPattern(slice_op->getOperand(1), m_Constant(&begin)) ||
        !matchPattern(slice_op->getOperand(2), m_Constant(&size)))
      return llvm::None;

    // Check if "begin" is a zero tensor.
    for (auto begin_idx : begin.getValues<APInt>())
      if (begin_idx != 0) return llvm::None;

    // Check if "size" is equal to slice_op.input.shape except for the last
    // dimension. This can be done by verifying the number of elements:
    // i.e., num_input / input_last_dim == num_result / k.
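    // For instance (a hypothetical example): slicing a tensor<2x3x10xf32>
    // input (60 elements) down to tensor<2x3x4xf32> (24 elements) passes this
    // check, since 60 / 10 * 4 == 24.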
    auto input_ty = value.getType().dyn_cast_or_null<ShapedType>();
    auto result_ty = slice_op.getType().dyn_cast<ShapedType>();
    if (!input_ty || !result_ty) return llvm::None;
    if (!input_ty.hasStaticShape() || !result_ty.hasStaticShape())
      return llvm::None;
    if (!input_ty.getRank() || !result_ty.getRank()) return llvm::None;
    int num_input = input_ty.getNumElements();
    int input_last_dim = input_ty.getShape().back();
    if (input_last_dim < 1) return llvm::None;
    int num_result = result_ty.getNumElements();
    auto size_last = *(--size.value_end<APInt>());
    int32_t k = size_last.getSExtValue();
    if (num_input / input_last_dim * k != num_result) return llvm::None;
    // We don't match a SliceOp whose last dim size is 0.
    if (!k) return llvm::None;
    return k;
  }

  LogicalResult matchAndRewrite(TFL::TopKV2Op op,
                                PatternRewriter &rewriter) const override {
    auto values = op.values();
    auto indices = op.indices();
    // op.values() and op.indices() cannot be used more than once.
    if (!values.hasOneUse() && !values.use_empty()) return failure();
    if (!indices.hasOneUse() && !indices.use_empty()) return failure();

    auto k_values_or = ComputeSliceK(values);
    auto k_indices_or = ComputeSliceK(indices);
    if (!k_values_or.has_value() || !k_indices_or.has_value()) return failure();
    int32_t k_values = k_values_or.getValue();
    int32_t k_indices = k_indices_or.getValue();
    // We don't match two SliceOps with different sizes.
    if (k_values != k_indices && !values.use_empty() && !indices.use_empty())
      return failure();

    // Start replacing.
    auto k = !values.use_empty() ? k_values : k_indices;
    // Build the scalar tensor k.
    auto k_ty = mlir::RankedTensorType::get({}, rewriter.getIntegerType(32));
    Value k_cst = rewriter.create<TFL::ConstOp>(
        op.getLoc(), DenseElementsAttr::get(k_ty, k));
    // Compute the new result types.
    auto values_ty = values.getType().dyn_cast<ShapedType>();
    auto indices_ty = indices.getType().dyn_cast<ShapedType>();
    auto shape = std::vector<int64_t>();
    for (auto d : values_ty.getShape().drop_back()) {
      shape.push_back(d);
    }
    shape.push_back(static_cast<int64_t>(k));
    auto new_values_ty =
        mlir::RankedTensorType::get(shape, values_ty.getElementType());
    auto new_indices_ty =
        mlir::RankedTensorType::get(shape, indices_ty.getElementType());
    TFL::TopKV2Op top_k_op = rewriter.create<TFL::TopKV2Op>(
        op.getLoc(), new_values_ty, new_indices_ty, op->getOperand(0), k_cst);

    // Remove the original ops (TopKV2, Slice, Slice).
    if (!values.use_empty()) {
      auto values_slice_op = llvm::dyn_cast_or_null<TFL::SliceOp>(
          values.getUses().begin().getUser());
      values_slice_op.getResult().replaceAllUsesWith(top_k_op.values());
      values_slice_op.erase();
    }
    if (!indices.use_empty()) {
      auto indices_slice_op = llvm::dyn_cast_or_null<TFL::SliceOp>(
          indices.getUses().begin().getUser());
      indices_slice_op.getResult().replaceAllUsesWith(top_k_op.indices());
      indices_slice_op.erase();
    }
    op.erase();
    return success();
  }
};

using FuseBinaryOpToFollowingFullyConnected =
    FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
using FuseBinaryOpToFollowingDepthwiseConv2D =
    FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;

// Adds canonicalization patterns to the list of patterns.
void AddCanonicalizationPatterns(MLIRContext *context,
                                 RewritePatternSet *patterns) {
  for (auto op : context->getRegisteredOperations())
    op.getCanonicalizationPatterns(*patterns, context);
}

void OptimizePass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto *ctx = &getContext();
  auto func = getOperation();

  // Merge reshapes into fully connected ops before we start moving them past
  // binary ops.
  RewritePatternSet phase_0_patterns(&getContext());
  phase_0_patterns
      .add<RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected>(
          ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns));

  // The binary ops might potentially be fused together into more complex ops
  // like hard_swish, so we explore those fusions first and then fuse the
  // binary ops with the following ops in a second pattern match.
  TFL::populateWithGenerated(patterns);
  patterns.add<FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
               FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
               FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
               FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
               FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>>(ctx);
  if (this->enable_canonicalization_)
    AddCanonicalizationPatterns(ctx, &patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Fuse the binary ops with the following ops.
  RewritePatternSet phase_2_patterns(&getContext());
  TFL::populateWithGenerated(phase_2_patterns);
  phase_2_patterns.add<
      ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
      ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
      FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
      FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
      FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
      FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
      FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
      FuseBinaryOpToFollowingConv2D, FuseBinaryOpToFollowingDepthwiseConv2D,
      FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
      FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp,
      RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected,
      FuseUnpackAndConcatToReshape, OptimizeTopK>(ctx);
  if (this->enable_canonicalization_)
    AddCanonicalizationPatterns(ctx, &phase_2_patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}
}  // namespace

// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreateOptimizePass(
    bool enable_canonicalization) {
  return std::make_unique<OptimizePass>(enable_canonicalization);
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateOptimizePass() {
  return std::make_unique<OptimizePass>();
}
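
// For reference, a minimal usage sketch (not part of this file; `module` and
// `context` are assumed to be a loaded mlir::ModuleOp and its MLIRContext):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(mlir::TFL::CreateOptimizePass(true));
//   if (failed(pm.run(module))) { /* handle pass failure */ }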

}  // namespace TFL
}  // namespace mlir