1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This transformation pass takes operations in the TensorFlow Lite dialect and
// optimizes them into more efficient operations in the same dialect.
18
19 #include <algorithm>
20 #include <climits>
21 #include <cstdint>
22 #include <functional>
23 #include <iterator>
24 #include <map>
25 #include <numeric>
26 #include <utility>
27
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/APInt.h"
30 #include "llvm/ADT/ArrayRef.h"
31 #include "llvm/ADT/None.h"
32 #include "llvm/ADT/Optional.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallSet.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/ADT/StringRef.h"
37 #include "llvm/ADT/StringSwitch.h"
38 #include "llvm/Support/Casting.h"
39 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
40 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
41 #include "mlir/IR/Attributes.h" // from @llvm-project
42 #include "mlir/IR/Builders.h" // from @llvm-project
43 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
44 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
45 #include "mlir/IR/MLIRContext.h" // from @llvm-project
46 #include "mlir/IR/Matchers.h" // from @llvm-project
47 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
48 #include "mlir/IR/Value.h" // from @llvm-project
49 #include "mlir/Pass/Pass.h" // from @llvm-project
50 #include "mlir/Support/LLVM.h" // from @llvm-project
51 #include "mlir/Support/LogicalResult.h" // from @llvm-project
52 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
53 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
54 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
55 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
56 #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
57 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
58 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
59 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
60
61 namespace mlir {
62 namespace TFL {
63
64 //===----------------------------------------------------------------------===//
65 // The actual Optimize Pass.
66 namespace {
67 #define GEN_PASS_CLASSES
68 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
69
70 constexpr char kRelu[] = "RELU";
71 constexpr char kRelu6[] = "RELU6";
72 constexpr char kRelu1[] = "RELU_N1_TO_1";
73
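// Returns true if reducing over `axis` is equivalent to reducing over the last
// dimension of `sq_op` (given as rank - 1 or as -1) or over all of its
// dimensions (0, 1, ..., rank - 1).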
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
75 if (axis.getNumElements() == 0) {
76 return false;
77 }
78 if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
79 *axis.getValues<int>().begin() ||
80 *axis.getValues<int>().begin() == -1) {
81 return true;
82 }
83 if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
84 return false;
85 }
86 auto shape = sq_op.getType().cast<ShapedType>();
87 SmallVector<int, 4> elems{axis.getValues<int>().begin(),
88 axis.getValues<int>().end()};
89 for (int i = 0; i < shape.getRank(); ++i) {
90 if (i != elems[i]) return false;
91 }
92 return true;
93 }
94
95 using ::llvm::cast;
96
97 // Optimize TFLite operations in functions.
98 class OptimizePass : public OptimizePassBase<OptimizePass> {
99 public:
100 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass)
101
102 OptimizePass() = default;
  OptimizePass(const OptimizePass &) {}
  explicit OptimizePass(bool enable_canonicalization) {
105 this->enable_canonicalization_ = enable_canonicalization;
106 }
107
108 void runOnOperation() override;
109 };
110
111 // Returns whether the given type `a` is broadcast-compatible with `b`.
bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
113 return OpTrait::util::getBroadcastedType(a, b) != Type();
114 }
115
116 // Returns whether the resultant type of any broadcastable operation with
117 // operands `a` and `b` matches `expected_output`. Returns false if `a` is not
118 // broadcast-compatible with `b`.
bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) {
120 Type output_element_type =
121 expected_output.cast<ShapedType>().getElementType();
122 Type broadcasted_type =
123 OpTrait::util::getBroadcastedType(a, b, output_element_type);
124 return broadcasted_type != Type() && broadcasted_type == expected_output;
125 }
126
// Returns true if the dimensions of `type1` are the same as the trailing
// dimensions of `type2`. This is more restrictive than broadcastability.
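// For example, a [1, 2] type is a tail of a [5, 1, 2] type, but a [2, 1] type
// is not.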
bool IsTailOfShape(Type type1, Type type2) {
130 auto tail_type = type1.dyn_cast<ShapedType>();
131 auto full_type = type2.dyn_cast<ShapedType>();
132 if (!tail_type || !full_type || !tail_type.hasRank() ||
133 !full_type.hasRank() || tail_type.getRank() > full_type.getRank())
134 return false;
135 auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
136 auto i2 = full_type.getShape().rbegin();
137 return std::equal(i1, e1, i2);
138 }
139
bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef<int64_t> filter_shape,
                                      const ArrayRef<int64_t> elements_shape,
                                      bool is_depthwise) {
  // The elements tensor must be of rank 1 or 0 (scalar).
144 const auto elements_rank = elements_shape.size();
145 if (elements_rank != 1 && elements_rank != 0) {
146 return false;
147 }
148 auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back();
  // If the elements depth equals 1 (i.e., a scalar or a tensor with a single
  // element), the binary op can simply broadcast the elements.
151 if (elements_depth == 1) {
152 return true;
153 }
154
  // In TFLite, Conv2D uses the OHWI format for its filter, and DepthwiseConv
  // uses 1HWO.
  // For conv: check that the first (output-channel) dimension of the filter
  // equals the elements depth.
  // For depthwise conv: check that the last (output-channel) dimension of the
  // filter equals the elements depth.
160 if (filter_shape.empty() ||
161 (is_depthwise ? filter_shape.back() != elements_depth
162 : filter_shape[0] != elements_depth))
163 return false;
164 return true;
165 }
166
bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val,
                                bool is_depthwise) {
169 const auto elements = val.dyn_cast<DenseElementsAttr>();
170 if (!elements) {
171 return false;
172 }
173 const auto elements_shape = elements.getType().getShape();
174 const auto filter_shape = filter.getType().cast<ShapedType>().getShape();
175 return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape,
176 is_depthwise);
177 }
178
bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
                                bool is_depthwise) {
181 if (const auto elements = val.dyn_cast<DenseElementsAttr>()) {
182 if (const auto filter_elements = filter.dyn_cast<DenseElementsAttr>()) {
183 return CanFuseConvOrDepthwiseConvShapes(
184 filter_elements.getType().getShape(), elements.getType().getShape(),
185 is_depthwise);
186 }
187 }
188 return false;
189 }
190
// Returns true if we can eliminate the GatherNdOp or ScatterNdOp: when the
// values of `indices` are 0 to n-1 in order, the output tensor is identical to
// `params`.
bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params,
                                              DenseIntElementsAttr indices,
                                              Type output_type) {
197 auto params_type = params.getType().dyn_cast<RankedTensorType>();
198 auto indices_type = indices.getType().dyn_cast<RankedTensorType>();
  // Checks that the shape of `params` is [n, ...] and the shape of `indices`
  // is [n, 1]. A 2D `indices` of shape [n, 1] indexes along the first
  // dimension of `params`; as long as the indices enumerate 0..n-1 in order,
  // the output is identical to the input.
202 if (!params_type || !indices_type || indices_type.getRank() != 2 ||
203 indices_type.getDimSize(0) != params_type.getDimSize(0) ||
204 indices_type.getDimSize(1) != 1)
205 return false;
206
207 // Checks the `params_type` is equal to `output_type`. If not equal, we
208 // cannot replace the scatter_nd/gather_nd op with `params`.
209 if (params_type != output_type) return false;
210
211 // Checks the value in `indices` is from 0 to n-1.
212 int cur_value = 0;
213 for (const auto &v : indices.getValues<APInt>()) {
214 if (v.getSExtValue() != cur_value) return false;
215 ++cur_value;
216 }
217
218 return true;
219 }
220
221 // Returns true if we can eliminate the SliceOp. When the values of `begin` are
222 // all 0s and `size[i]` is equal to either -1 or `input.shape[i]`
223 // for each dim i, the output tensor is identical to `input`.
bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) {
225 // Checks if `begin` and `size` are i32 or i64.
226 auto begin_attr = begin.dyn_cast<DenseIntElementsAttr>();
227 auto size_attr = size.dyn_cast<DenseIntElementsAttr>();
228 if (!begin_attr || !size_attr) {
229 return false;
230 }
231
232 auto begin_elem_ty = begin_attr.getType().getElementType();
233 if (!begin_elem_ty.isInteger(32) && !begin_elem_ty.isInteger(64)) {
234 return false;
235 }
236 auto size_elem_ty = size_attr.getType().getElementType();
237 if (!size_elem_ty.isInteger(32) && !size_elem_ty.isInteger(64)) {
238 return false;
239 }
240
241 // Checks if `input` is ranked and its rank is equal to number of elements in
242 // `begin` and `size`.
243 auto input_ty = input.getType().cast<ShapedType>();
244 if (!input_ty.hasRank()) {
245 return false;
246 }
247
248 int64_t rank = input_ty.getRank();
249 if (rank != begin_attr.getNumElements() ||
250 rank != size_attr.getNumElements()) {
251 return false;
252 }
253
254 // Checks if `begin` is all 0s, and `size[i]` is equal to either -1 or
255 // `input.shape[i]`.
256 for (uint64_t i = 0; i < rank; ++i) {
257 if (begin_attr.getValues<APInt>()[i].getSExtValue() != 0) return false;
258 int64_t si = size_attr.getValues<APInt>()[i].getSExtValue();
259 if (si != -1 && si != input_ty.getDimSize(i)) return false;
260 }
261
262 return true;
263 }
264
// Expands Attribute 'a' to 4D with all dimensions equal to 1 except for one;
// which dimension that is depends on whether 'is_depthwise' is true or false.
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) {
268 auto elements = a.dyn_cast<DenseElementsAttr>();
269 auto shape = elements.getType().getShape();
270 if (!shape.empty()) {
271 // Checks that elements are essentially 1d.
272 assert(elements.getNumElements() == shape.back());
273 }
274 std::vector<int64_t> shape_data = {1, 1, 1, 1};
275 const int vector_length = elements.getNumElements();
276 if (is_depthwise)
277 shape_data[3] = vector_length;
278 else
279 shape_data[0] = vector_length;
280 auto new_shape =
281 RankedTensorType::get(shape_data, elements.getType().getElementType());
282 return elements.reshape(new_shape);
283 }
284
ElementsAttr ExpandTo4DForConv(Attribute a) {
286 return ExpandTo4DForConvImpl(a, false);
287 }
288
ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
290 return ExpandTo4DForConvImpl(a, true);
291 }
292
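// Rescales the quantized type of `input` by the constant `factor`; a thin
// wrapper around quant::RescaleQuantizedType.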
TypeAttr RescaleQtype(Type input, Attribute factor) {
294 return quant::RescaleQuantizedType(input, factor);
295 }
296
// Returns the shape of a ranked tensor as a 1-D i32 DenseElementsAttr.
// Precondition: output_val's type is a ranked tensor.
DenseElementsAttr GetShape(Value output_val) {
300 auto output_type = output_val.getType().cast<RankedTensorType>();
301 auto shape_vector = output_type.getShape();
302 std::vector<int32_t> shape;
303 shape.reserve(shape_vector.size());
304 for (auto shape_object : shape_vector) {
305 shape.push_back(shape_object);
306 }
307 return mlir::DenseElementsAttr::get(
308 RankedTensorType::get(
309 {static_cast<int>(shape.size())},
310 mlir::IntegerType::get(output_val.getContext(), 32)),
311 llvm::makeArrayRef(shape));
312 }
313
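// Returns the element type of `type_attr` if it holds a ShapedType; otherwise
// returns the held type itself.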
static Type GetShapeStrippedType(TypeAttr type_attr) {
315 auto type = type_attr.getValue();
316 auto shaped_type = type.dyn_cast<ShapedType>();
317 if (shaped_type) {
318 return shaped_type.getElementType();
319 } else {
320 return type;
321 }
322 }
323
324 // Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in
325 // the specified `shape` and `false` otherwise.
static bool ShapeMatchesReduceWithKeepAxes(Value input,
                                           const mlir::Attribute &axes,
                                           const mlir::Attribute &shape) {
329 RankedTensorType type = input.getType().dyn_cast_or_null<RankedTensorType>();
330 if (!type) return false;
331
332 DenseIntElementsAttr axes_attr =
333 axes.dyn_cast_or_null<DenseIntElementsAttr>();
334 DenseIntElementsAttr shape_attr =
335 shape.dyn_cast_or_null<DenseIntElementsAttr>();
336 if (!axes_attr || !shape_attr) return false;
337
338 if (shape_attr.getNumElements() != type.getRank()) return false;
339
340 llvm::SmallSet<uint64_t, 4> axes_set;
341 for (auto a : axes_attr.getValues<APInt>()) {
342 axes_set.insert(a.getZExtValue());
343 }
344
345 auto type_shape = type.getShape();
346 for (uint64_t i = 0; i < type.getRank(); ++i) {
347 if (axes_set.contains(i)) {
348 if (shape_attr.getValues<APInt>()[i] != 1) return false;
349 } else {
350 if (shape_attr.getValues<APInt>()[i] != type_shape[i]) return false;
351 }
352 }
353 return true;
354 }
355
356 // Returns `true` if all the `axes` dimensions of `input` are 1.
static bool AreInputDimensionsOneInAxes(Value input,
                                        const mlir::Attribute &axes) {
359 RankedTensorType input_type =
360 input.getType().dyn_cast_or_null<RankedTensorType>();
361 if (!input_type) return false;
362 auto type_shape = input_type.getShape();
363
364 DenseIntElementsAttr axes_attr =
365 axes.dyn_cast_or_null<DenseIntElementsAttr>();
366 if (!axes_attr) return false;
367
368 for (auto a : axes_attr.getValues<APInt>()) {
369 int64_t axis = a.getSExtValue();
370 if (axis < 0) {
371 axis += type_shape.size();
372 }
373 if (axis < 0 || axis >= type_shape.size()) {
374 // `axis` is not a valid axis in input.
375 return false;
376 }
377 if (type_shape[axis] != 1) {
378 return false;
379 }
380 }
381
382 return true;
383 }
384
static bool FloatValueEquals(const Attribute &attr, double value) {
386 auto fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
387 if (!fp_attr) return false;
388
389 if (fp_attr.isSplat()) {
390 return fp_attr.getSplatValue<APFloat>().isExactlyValue(value);
391 }
392 return llvm::all_of(fp_attr.getValues<APFloat>(), [value](const APFloat &f) {
393 return f.isExactlyValue(value);
394 });
395 }
396
397 // Returns true if the value's element type is F32.
bool IsF32Value(Value value) {
399 return value.getType().cast<ShapedType>().getElementType().isF32();
400 }
401
// Returns the number of elements in attr if it is a DenseElementsAttr, and 1
// otherwise, as a scalar (0-d) int32 Attribute.
Attribute GetNumElementsOrOne(Attribute attr) {
405 const auto dense_attr = attr.dyn_cast_or_null<DenseElementsAttr>();
406 int32_t num_elements = dense_attr ? dense_attr.getNumElements() : 1;
407
408 OpBuilder builder(attr.getContext());
409
410 return DenseIntElementsAttr::get(
411 RankedTensorType::get({}, builder.getI32Type()),
412 {llvm::APInt(32, num_elements, true)});
413 }
414
bool HasExactlyTwoElements(Attribute attr) {
416 const auto values = attr.dyn_cast_or_null<ElementsAttr>();
417 if (!values) return false;
418 return values.getNumElements() == 2;
419 }
420
// Returns true if attr is a DenseIntElementsAttr whose last element equals 1.
bool IsLastElementEqualsOne(Attribute attr) {
423 const auto ints = attr.dyn_cast_or_null<DenseIntElementsAttr>();
424 if (!ints) return false;
425 if (ints.empty()) return false;
426 const auto last_element_index = ints.getNumElements() - 1;
427 const auto iterator = ints.value_begin<int>();
428 const int last_element = iterator[last_element_index];
429 return last_element == 1;
430 }
431
// Reshapes `value` to the given shape with its last dimension dropped.
Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value,
                                  Attribute shape) {
  // This function is always guarded with IsLastElementEqualsOne(), so we can
  // cast safely here.
437 const auto old_shape = shape.cast<DenseIntElementsAttr>();
438 auto iterator = old_shape.value_begin<int>();
439 SmallVector<int, 4> new_shape;
440 SmallVector<int64_t, 4> new_shape_i64;
441 for (int i = 0; i < old_shape.size() - 1; ++i) {
442 new_shape.push_back(*iterator);
443 new_shape_i64.push_back(*iterator);
444 ++iterator;
445 }
446 return builder.create<ReshapeOp>(
447 value.getLoc(),
448 RankedTensorType::get(
449 new_shape_i64, value.getType().cast<ShapedType>().getElementType()),
450 value,
451 builder.create<arith::ConstantOp>(
452 value.getLoc(), DenseIntElementsAttr::get(
453 RankedTensorType::get({old_shape.size() - 1},
454 builder.getI32Type()),
455 new_shape)));
456 }
457
458 // Returns true if val has a static shape and the last dimension equals 1.
bool IsLastDimensionEqualOne(Value val) {
460 const auto val_type = val.getType().cast<ShapedType>();
461 if (!val_type.hasStaticShape()) return false;
462 const auto val_shape = val_type.getShape();
463 if (val_shape.empty()) return false;
464 const auto last_element = *val_shape.rbegin();
465 return last_element == 1;
466 }
467
// Returns true if attr is a DenseIntElementsAttr of int32 or int64 values that
// form an incrementing sequence from 0 to N-1.
//
// If such a value is used in an Equal operator, it can be replaced with OneHot.
bool IsOneHotIndexAttribute(Attribute attr) {
473 const auto dense_attr = attr.dyn_cast_or_null<DenseIntElementsAttr>();
474 if (!dense_attr) {
475 return false;
476 }
477 auto index_type = dense_attr.getType();
478 const auto index_elem_bits = index_type.getElementTypeBitWidth();
479 if (index_elem_bits != 32 && index_elem_bits != 64) {
480 return false;
481 }
482 if (index_type.getRank() != 1) {
483 return false;
484 }
485 const auto elems = dense_attr.value_begin<APInt>();
486 for (int i = 0; i < dense_attr.getNumElements(); ++i) {
487 if (i != elems[i]) {
488 return false;
489 }
490 }
491 return true;
492 }
493
494 // Creates FullyConnected op from params and returns the output.
mlir::Value GetFcOutput(OpBuilder *builder,
                        ::mlir::Operation::result_range result, Value input,
                        Value filter, Value bias,
                        StringAttr fused_activation_function,
                        StringAttr weights_format, BoolAttr keep_num_dims,
                        BoolAttr asymmetric_quantize_inputs) {
501 auto fc_op = builder->create<FullyConnectedOp>(
502 result[0].getLoc(), result.getTypes(), input, filter, bias,
503 fused_activation_function, weights_format, keep_num_dims,
504 asymmetric_quantize_inputs);
505 return fc_op->getResult(0);
506 }
507
// Returns true if 'value' represents a const ElementsAttr with all values
// equal to 0.0.
bool AllValuesAreZero(mlir::Value value) {
511 if (!value) return false;
512 DenseElementsAttr vals;
513 if (!matchPattern(value, m_Constant(&vals))) return false;
514 for (auto elem : vals.getValues<float>())
515 if (elem != 0.0f) return false;
516 return true;
517 }
518
519 // Converts an Attribute with a single value of float or integral type to an
520 // Attribute holding a single value of float type. If attr has no elements, the
521 // result is 0.0f.
Attribute ConvertSingleElementAttrToFloatAttr(Attribute attr) {
523 const auto dense_fp_attr = attr.dyn_cast_or_null<DenseFPElementsAttr>();
524 if (dense_fp_attr) {
525 // Already float => return
526 return dense_fp_attr;
527 }
528
529 OpBuilder builder(attr.getContext());
530
531 const auto dense_int_attr = attr.dyn_cast<DenseIntElementsAttr>();
532 const auto int_values = dense_int_attr.getValues<APInt>();
533 float float_val = 0.0f;
534 if (!int_values.empty()) {
535 const APInt apint_val = *int_values.begin();
536 if (dense_int_attr.getType().getElementType().isSignedInteger()) {
537 // Get the sign-extended value (=>int64) if the type is signed.
538 float_val = apint_val.getSExtValue();
539 } else {
540 // Get the zero-extended value (=>uint64) if unsigned or signless.
541 float_val = apint_val.getZExtValue();
542 }
543 }
544 return DenseFPElementsAttr::get(
545 RankedTensorType::get({}, builder.getF32Type()),
546 {llvm::APFloat(float_val)});
547 }
548
549 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
550
// Fuse Add with the preceding FullyConnected.
552 // TODO(b/136285429): Move to tablegen when variadic is supported
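// For illustration (attributes omitted), the pattern rewrites
//   %fc  = "tfl.fully_connected"(%input, %filter, %bias)
//   %out = "tfl.add"(%fc, %cst)
// into a single tfl.fully_connected whose bias incorporates %cst; the bias
// addition is expected to be constant-folded away.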
553 struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
554 using OpRewritePattern<TFL::AddOp>::OpRewritePattern;
555
  LogicalResult matchAndRewrite(TFL::AddOp add_op,
                                PatternRewriter &rewriter) const override {
558 // Match Add.
559 DenseElementsAttr added_value;
560 Value constant_val = add_op.rhs();
561 if (!matchPattern(constant_val, m_Constant(&added_value))) return failure();
562
563 // Match Fully Connected.
564 auto fc_op =
565 dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
566 if (!fc_op) return failure();
567
568 // Check if the constant RHS is either 0D (scalar), or a 1D with
569 // `{num_channels}` shape.
570 auto constant_val_type = constant_val.getType().cast<TensorType>();
571
    // In the TFLite FullyConnected definition, bias must be a 1D tensor where
573 // the number of elements is equal to the number of channels.
574 // If it's not 1D or 0D (which can be broadcasted to 1D), reject the
575 // matching.
576 bool is_scalar_rhs = false;
577 if (constant_val_type.getRank() == 0) {
578 is_scalar_rhs = true;
579 } else if (constant_val_type.getRank() != 1) {
580 return failure();
581 }
582
583 Value filter = fc_op.filter();
584 Value bias = fc_op.bias();
585 ElementsAttr bias_value;
586 const bool is_none_bias = bias.getType().isa<NoneType>();
587 if (fc_op.fused_activation_function() != "NONE") return failure();
588
589 if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
590 return failure();
591
592 // Rewrite
593 if (is_none_bias) {
594 if (is_scalar_rhs) {
        // If the `constant_val` is a scalar, we must use the shape of the
        // filter to properly broadcast the scalar to a `{num_channels}` shape.
597
598 // Get the number of channels if possible.
599 auto filter_type = filter.getType().dyn_cast<RankedTensorType>();
600 // Filter must be a `2D` tensor with `{num_channels, num_features}`
601 // shape. The following check is rejecting unknown rank (-1).
602 if (filter_type == nullptr || filter_type.getRank() != 2) {
603 return failure();
604 }
605 int num_channels = filter_type.getShape()[0];
606
        // Create a zero tensor with shape {num_channels}; its type needs to
        // be the same as the type of constant_val.
        // This is a way to gracefully handle a scalar tensor. The Add will
        // always be constant-folded away regardless of whether `constant_val`
        // is a scalar or not.
612 RankedTensorType type = RankedTensorType::get(
613 {num_channels}, constant_val_type.getElementType());
614 auto attr = rewriter.getZeroAttr(type);
615 bias = rewriter.create<arith::ConstantOp>(add_op.getLoc(), type, attr);
616 auto none_af = rewriter.getStringAttr("NONE");
617 bias =
618 rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
619 .output();
620 } else {
        // If there is no pre-existing bias and the `constant_val` is 1D,
        // simply use `constant_val` as the bias.
623 bias = constant_val;
624 }
625 } else {
626 auto none_af = rewriter.getStringAttr("NONE");
627 bias =
628 rewriter.create<AddOp>(add_op.getLoc(), bias, constant_val, none_af)
629 .output();
630 }
631
632 auto fc = rewriter.create<TFL::FullyConnectedOp>(
633 FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), add_op.getLoc()}),
634 add_op.getType(),
635 /*input=*/fc_op.input(),
636 /*filter=*/filter,
637 /*bias=*/bias,
638 /*fused_activation_function=*/
639 rewriter.getStringAttr(add_op.fused_activation_function()),
640 /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
641 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
642 /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
643 rewriter.replaceOp(add_op, fc.output());
644
645 return success();
646 }
647 };
648
649 // Replace ..
650 // FC(Add(lhs, rhs), filter, bias)
651 // .. with ..
652 // FC(lhs, filter, FC(rhs, filter, bias))
653 // .. if rhs, filter, and bias are all constants.
654 // The second FC will be constant folded to a single vector.
655 // TODO(b/136285429): Move to tablegen when variadic is supported
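// For illustration (attributes omitted), with constant %rhs and %bias:
//   %add = "tfl.add"(%lhs, %rhs)
//   %fc  = "tfl.fully_connected"(%add, %filter, %bias)
// becomes
//   %new_bias = "tfl.fully_connected"(%rhs, %filter, %bias)  // folds to a constant
//   %fc       = "tfl.fully_connected"(%lhs, %filter, %new_bias)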
656 struct FuseAddAndFullyConnected
657 : public OpRewritePattern<TFL::FullyConnectedOp> {
658 using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
659
  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
                                PatternRewriter &rewriter) const override {
662 // This only works with default format.
663 if (fc_op.weights_format() != "DEFAULT") return failure();
664
665 // Match Add.
666 auto add_op = dyn_cast_or_null<TFL::AddOp>(fc_op.input().getDefiningOp());
667 if (!add_op) return failure();
668 if (add_op.fused_activation_function() != "NONE") return failure();
669
670 // Don't match adds where the added constant is not 1D.
671 {
672 auto addend_shape = add_op.rhs().getType().cast<ShapedType>();
673 if (!addend_shape.hasStaticShape()) return failure();
674 if (addend_shape.getShape().size() != 1) return failure();
675 }
676
677 // Calculate new bias. Generate a new FC; it will be constant folded.
678 auto old_bias = fc_op.bias();
679 if (!old_bias || old_bias.getType().isa<NoneType>()) {
680 // TODO(b/180752069): Figure out new bias' type when old bias is empty.
681 return failure();
682 }
683
    // The new FC relies on constant folding, which is only implemented for
    // F32. Check that the types are F32.
686 {
687 if (!IsF32Value(add_op.rhs()) || !IsF32Value(fc_op.filter()) ||
688 !IsF32Value(old_bias))
689 return failure();
690 }
691
692 auto new_bias = rewriter.create<TFL::FullyConnectedOp>(
693 fc_op.getLoc(), old_bias.getType(),
694 /*input=*/add_op.rhs(),
695 /*filter=*/fc_op.filter(),
696 /*bias=*/old_bias,
697 /*fused_activation_function=*/rewriter.getStringAttr("NONE"),
698 /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
699 /*keep_num_dims=*/rewriter.getBoolAttr(true),
700 /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
701
702 // Create the updated FC.
703 auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
704 FusedLoc::get(add_op.getContext(), {add_op.getLoc(), fc_op.getLoc()}),
705 fc_op.output().getTypes(),
706 /*input=*/add_op.lhs(),
707 /*filter=*/fc_op.filter(),
708 /*bias=*/*new_bias.output().begin(),
709 /*fused_activation_function=*/
710 rewriter.getStringAttr(fc_op.fused_activation_function()),
711 /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
712 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
713 /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
714 rewriter.replaceOp(fc_op.getOperation(), new_fc.output());
715
716 return success();
717 }
718 };
719
720 // Replace ..
721 // FC(Mul(lhs, rhs), filter, bias)
722 // .. with ..
723 // FC(lhs, Mul(filter, rhs), bias)
724 // .. if rhs, filter, and bias are all constants.
725 // The generated Mul will be constant folded to a single matrix.
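// For illustration (attributes omitted), with constant %rhs:
//   %mul = "tfl.mul"(%lhs, %rhs)
//   %fc  = "tfl.fully_connected"(%mul, %filter, %bias)
// becomes
//   %new_filter = "tfl.mul"(%filter, %rhs)  // folds to a constant
//   %fc         = "tfl.fully_connected"(%lhs, %new_filter, %bias)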
726 struct FuseMulAndFullyConnected
727 : public OpRewritePattern<TFL::FullyConnectedOp> {
728 using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
729
  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op,
                                PatternRewriter &rewriter) const override {
732 // This only works with default format.
733 if (fc_op.weights_format() != "DEFAULT") return failure();
734
735 // Match Mul.
736 auto mul_op = dyn_cast_or_null<TFL::MulOp>(fc_op.input().getDefiningOp());
737 if (!mul_op) return failure();
738 if (mul_op.fused_activation_function() != "NONE") return failure();
739
740 // Don't match muls where the multiplier constant is not 1D.
741 {
742 auto multiplier_shape = mul_op.rhs().getType().cast<ShapedType>();
743 if (!multiplier_shape.hasStaticShape()) return failure();
744 if (multiplier_shape.getShape().size() != 1) return failure();
745 }
746
747 // We rely on constant folding, implemented only for F32. Check types.
748 if (!IsF32Value(mul_op.rhs()) || !IsF32Value(fc_op.filter())) {
749 return failure();
750 }
751
752 auto location =
753 FusedLoc::get(mul_op.getContext(), {mul_op.getLoc(), fc_op.getLoc()});
754
755 auto new_filter = rewriter.create<TFL::MulOp>(
756 location,
757 /*lhs=*/fc_op.filter(),
758 /*rhs=*/mul_op.rhs(),
759 /*fused_activation_function=*/rewriter.getStringAttr("NONE"));
760 // Create the updated FC.
761 auto new_fc = rewriter.create<TFL::FullyConnectedOp>(
762 location, fc_op.output().getTypes(),
763 /*input=*/mul_op.lhs(),
764 /*filter=*/new_filter,
765 /*bias=*/fc_op.bias(),
766 /*fused_activation_function=*/
767 rewriter.getStringAttr(fc_op.fused_activation_function()),
768 /*weights_format=*/rewriter.getStringAttr("DEFAULT"),
769 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
770 /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
771 rewriter.replaceOp(fc_op.getOperation(), new_fc.output());
772
773 return success();
774 }
775 };
776
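// Fuses a trailing ReLU-family activation (RELU, RELU6, RELU_N1_TO_1) into the
// preceding FullyConnected op by setting its fused_activation_function, when
// the FullyConnected op does not already have a fused activation.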
777 // TODO(b/136285429): Move to tablegen when variadic is supported.
778 template <typename ReluXOp, char const *Act>
779 struct FuseFullyConnectedAndReluX : public OpRewritePattern<ReluXOp> {
780 using OpRewritePattern<ReluXOp>::OpRewritePattern;
781
  LogicalResult matchAndRewrite(ReluXOp relu_op,
                                PatternRewriter &rewriter) const override {
784 Operation *input = relu_op.getOperand().getDefiningOp();
785 if (!isa_and_nonnull<FullyConnectedOp>(input)) return failure();
786 auto fully_connected_op = cast<FullyConnectedOp>(input);
787 if (fully_connected_op.fused_activation_function() != "NONE")
788 return failure();
789
790 auto new_activation_func = rewriter.getStringAttr(Act);
791 auto new_weights_format =
792 rewriter.getStringAttr(fully_connected_op.weights_format());
793 auto new_keep_num_dims =
794 rewriter.getBoolAttr(fully_connected_op.keep_num_dims());
795 auto fc = rewriter.create<FullyConnectedOp>(
796 FusedLoc::get(relu_op.getContext(),
797 {fully_connected_op.getLoc(), relu_op.getLoc()}),
798 relu_op.getType(), /*input=*/fully_connected_op.input(),
799 /*filter=*/fully_connected_op.filter(),
800 /*bias=*/fully_connected_op.bias(),
801 /*fused_activation_function=*/new_activation_func,
802 /*weights_format=*/new_weights_format,
803 /*keep_num_dims=*/new_keep_num_dims,
804 /*asymmetric_quantize_inputs=*/
805 fully_connected_op.asymmetric_quantize_inputsAttr());
806 rewriter.replaceOp(relu_op, fc.output());
807
808 return success();
809 }
810 };
811
// Fuse Mul with the preceding FullyConnected.
813 // TODO(b/136285429): Move to tablegen when variadic is supported
814 struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
815 using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
816
  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
819 // If we are broadcasting on the lhs then don't fold the multiply as it
820 // would increase the amount of compute done by the fully connected op.
821 if (mul_op.lhs().getType() != mul_op.getType()) return failure();
822
823 // Mul.
824 DenseElementsAttr cst;
825 Value constant_val = mul_op.rhs();
826 if (!matchPattern(constant_val, m_Constant(&cst))) return failure();
827
828 // Fully Connected.
829 auto fc_op =
830 dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
831 if (!fc_op) return failure();
832 Value filter = fc_op.filter();
833 Value bias = fc_op.bias();
834 ElementsAttr cst_tmp;
835 if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
836 if (!bias.getType().isa<NoneType>() &&
837 !matchPattern(bias, m_Constant(&cst_tmp)))
838 return failure();
839 if (fc_op.fused_activation_function() != "NONE") return failure();
840
    // Only fuse the multiplier if all dimensions other than the depth
    // dimension are equal to 1 since otherwise
    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
    // even if `filter` and `cst` are broadcastable.
845 auto shape = cst.getType().getShape();
846 if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
847
848 int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
849 // Expand and transpose the multiplier since weights are using the
850 // OHWI data format in TFLite.
851 int64_t normalized_shape[2] = {element_size, 1};
852 auto new_cst = cst.reshape(RankedTensorType::get(
853 normalized_shape, cst.getType().getElementType()));
854 Type new_type = new_cst.getType();
855 if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
856 return failure();
857 }
858
859 auto new_op =
860 rewriter.create<arith::ConstantOp>(mul_op.getLoc(), new_type, new_cst);
861 Value new_const_val = new_op.getResult();
862
863 // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
864 // TF::MulOp is used to fold the constant.
865 // TODO(b/139192933): switch to the TFL constant folding
866 auto new_filter =
867 rewriter.create<TF::MulOp>(mul_op.getLoc(), filter, new_const_val).z();
868 // If bias isn't None, it needs to be multiplied as well.
869 if (!bias.getType().isa<NoneType>()) {
870 bias =
871 rewriter.create<TF::MulOp>(mul_op.getLoc(), bias, constant_val).z();
872 }
873
874 auto fc = rewriter.create<TFL::FullyConnectedOp>(
875 FusedLoc::get(fc_op.getContext(), {fc_op.getLoc(), mul_op.getLoc()}),
876 mul_op.getType(),
877 /*input=*/fc_op.input(),
878 /*filter=*/new_filter,
879 /*bias=*/bias,
880 /*fused_activation_function=*/
881 rewriter.getStringAttr(mul_op.fused_activation_function()),
882 /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
883 /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()),
884 /*asymmetric_quantize_inputs=*/fc_op.asymmetric_quantize_inputsAttr());
885 rewriter.replaceOp(mul_op, fc.output());
886
887 return success();
888 }
889 };
890
// Fuse Mul with the preceding Affine ops. This is a C++ implementation of the
// following TableGen pattern, which does not derive the result type of the
// TFL_DequantizeOp:
894 // def : Pat<(TFL_MulOp (TFL_Conv2DOp:$conv_output $input,
895 // (TFL_DequantizeOp (TFL_QuantizeOp
896 // (Arith_ConstantOp F32ElementsAttr:$filter),
897 // $qtype)),
898 // (Arith_ConstantOp F32ElementsAttr:$bias),
899 // $h_factor, $w_factor, TFL_AF_None,
900 // $padding, $stride_h, $stride_w),
901 // (Arith_ConstantOp F32ElementsAttr:$value), $act_fn),
902 // (TFL_Conv2DOp $input,
903 // (TFL_DequantizeOp (TFL_QuantizeOp
904 // (TFL_MulOp (Arith_ConstantOp $filter),
905 // (Arith_ConstantOp (ExpandTo4DForConv
906 // $value)),
907 // TFL_AF_None),
908 // (RescaleQtype $qtype, $value))),
909 // (TFL_MulOp (Arith_ConstantOp $bias), (Arith_ConstantOp
910 // $value),
911 // TFL_AF_None),
912 // $h_factor, $w_factor, $act_fn,
913 // $padding, $stride_h, $stride_w),
914 // [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
915 // (HasOneUse $conv_output),
916 // (IsPerAxisQuantization $qtype), // per-axis quantization
917 // ]>;
918 template <typename AffineOpType>
919 struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern<TFL::MulOp> {
920 using OpRewritePattern<TFL::MulOp>::OpRewritePattern;
921
  LogicalResult matchAndRewrite(TFL::MulOp mul_op,
                                PatternRewriter &rewriter) const override {
    // Mul. A 1-D rhs is required for batch normalization.
925 DenseElementsAttr gamma_cst;
926 Value gamma = mul_op.rhs();
927 if (!matchPattern(gamma, m_Constant(&gamma_cst))) return failure();
928 if (gamma_cst.getType().getRank() != 1) return failure();
929
930 // Affine op
931 Operation *mul_op_lhs = mul_op.lhs().getDefiningOp();
932 auto fc_op = dyn_cast_or_null<AffineOpType>(mul_op_lhs);
933 if (!fc_op) return failure();
934 Value filter = fc_op.filter();
935 Value bias = fc_op.bias();
936
937 // QDQs
938 auto dq_op = dyn_cast_or_null<TFL::DequantizeOp>(filter.getDefiningOp());
939 if (!dq_op) return failure();
940 auto q_op =
941 dyn_cast_or_null<TFL::QuantizeOp>(dq_op.input().getDefiningOp());
942 if (!q_op) return failure();
943 filter = q_op.input();
944
945 // weight constant
946 ElementsAttr cst_tmp;
947 if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure();
948 if (!bias.getType().isa<NoneType>() &&
949 !matchPattern(bias, m_Constant(&cst_tmp)))
950 return failure();
951 if (fc_op.fused_activation_function() != "NONE") return failure();
952
953 // Broadcast the constant operand of Mul if it isn't compatible to the
954 // filter input. We only support broadcasting the operand along the depth
955 // dimension, when the operand's depth is 1.
956 rewriter.setInsertionPoint(q_op);
957 Location loc = fc_op.getLoc();
958 Value broadcasted_gamma;
959 if (isa<TFL::Conv2DOp>(mul_op_lhs)) {
960 auto mul_rhs = ExpandTo4DForConv(gamma_cst);
961 broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
962 } else if (isa<TFL::DepthwiseConv2DOp>(mul_op_lhs)) {
963 auto mul_rhs = ExpandTo4DForDepthwiseConv(gamma_cst);
964 broadcasted_gamma = rewriter.create<ConstOp>(loc, mul_rhs);
965 } else {
966 return failure();
967 }
968
969 // Make sure that the fused bias will be a 1D tensor.
970 auto gamma_shape = gamma.getType().cast<ShapedType>();
971 if (!gamma_shape.hasRank() || gamma_shape.getRank() != 1) {
972 return failure();
973 }
974
975 // Rewrite filter constant. Since the folder of TFL::MulOp couldn't
976 // broadcast the operands, TF::MulOp is used to fold the constant.
977 auto new_filter =
978 rewriter.create<TF::MulOp>(loc, filter, broadcasted_gamma).z();
979 // Update the scale in the quantize op.
980 auto new_qtype = RescaleQtype(q_op.qtype(), gamma_cst);
981 if (!new_qtype) return failure();
982 rewriter.replaceOpWithNewOp<TFL::QuantizeOp>(q_op, new_qtype.getValue(),
983 new_filter, new_qtype);
984
985 // If bias isn't None, it needs to be multiplied as well.
986 if (!bias.getType().isa<NoneType>()) {
987 rewriter.setInsertionPoint(fc_op);
988 auto new_bias = rewriter.create<TF::MulOp>(loc, bias, gamma);
989 fc_op.getOperation()->replaceUsesOfWith(bias, new_bias);
990 }
991
    // Remove the trailing mul op.
993 mul_op.replaceAllUsesWith(fc_op.getResult());
994 return success();
995 }
996 };
997
998 using FuseConv2DAndMulWithQDQs = FuseAffinOpAndMulWithQDQs<TFL::Conv2DOp>;
999 using FuseDepthwiseConv2DAndMulWithQDQs =
1000 FuseAffinOpAndMulWithQDQs<TFL::DepthwiseConv2DOp>;
1001
1002 // Fuse Binary Op with following Affine operation.
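// When the binary op is an add/sub with a scalar constant, the bias of the
// affine op is updated: w * (x + c) + b => w * x + (w * c + b). When it is a
// mul/div with a scalar constant, the filter is updated:
// w * (x * c) + b => (w * c) * x + b.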
1003 template <typename AffineOpType>
1004 struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
1005 using OpRewritePattern<AffineOpType>::OpRewritePattern;
1006
  LogicalResult matchAndRewrite(AffineOpType fc_op,
                                PatternRewriter &rewriter) const override {
1009 // Binary op.
1010 Operation *binary_op = fc_op.input().getDefiningOp();
1011 if (!binary_op || binary_op->getNumOperands() != 2) return failure();
    // We only handle the case where the RHS is a scalar.
    // TODO(fengliuai): Currently the canonicalizer pass cannot guarantee that
    // the constant operand is on the RHS; we may need to consider an LHS
    // constant operand if necessary.
1016 DenseFPElementsAttr cst;
1017 if (!matchPattern(binary_op->getOperand(1), m_Constant(&cst)))
1018 return failure();
1019 if (cst.getNumElements() != 1) return failure();
1020 APFloat cst_value = *cst.value_begin<APFloat>();
1021
1022 // Affine op.
1023 Value filter = fc_op.filter();
1024 Value bias = fc_op.bias();
1025 DenseFPElementsAttr filter_cst, bias_cst;
1026 if (!matchPattern(filter, m_Constant(&filter_cst))) {
      // The filter may be quantized; in that case, use the real constant
      // behind the quantize/dequantize ops.
1028 auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
1029 if (!dq) return failure();
1030 auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
1031 if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
1032 return failure();
1033 }
1034 filter = q.input();
1035 }
1036 if (!bias.getType().isa<NoneType>() &&
1037 !matchPattern(bias, m_Constant(&bias_cst)))
1038 return failure();
1039 auto binary_op_activation_func =
1040 binary_op->template getAttrOfType<StringAttr>(
1041 "fused_activation_function");
1042 if (!binary_op_activation_func ||
1043 binary_op_activation_func.getValue() != "NONE")
1044 return failure();
1045 ShapedType filter_type = filter_cst.getType();
1046
1047 if (llvm::isa<AddOp, SubOp>(binary_op)) {
1048 auto padding = fc_op->template getAttrOfType<StringAttr>("padding");
1049 if (padding && padding.getValue() != "VALID") return failure();
1050
1051 // The fusion of add/sub is actually applying the following
1052 // transformation:
1053 // w * (x + c) + b => w * x + (w * c + b)
1054 // so we have to update the bias.
1055 if (llvm::isa<SubOp>(binary_op)) cst_value.changeSign();
1056
1057 auto bias_and_slice =
1058 GetBiasDimAndSliceSize(filter_type.getShape(), fc_op);
1059 int64_t bias_size = bias_and_slice.first;
1060 int64_t slice_size = bias_and_slice.second;
1061 ShapedType new_bias_type =
1062 RankedTensorType::get({bias_size}, filter_type.getElementType());
1063
      // The new bias should be a 1-D tensor with length equal to the bias
      // dimension of the weight.
1066 SmallVector<APFloat, 4> new_bias_values;
1067 if (bias.getType().isa<NoneType>()) { // none bias, a list of zeros
1068 new_bias_values.resize(bias_size,
1069 APFloat::getZero(cst_value.getSemantics()));
1070 } else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
1071 new_bias_values.resize(bias_size, *bias_cst.value_begin<APFloat>());
1072 } else if (bias_cst.getNumElements() == bias_size) { // 1-d bias, copy it
1073 new_bias_values.insert(new_bias_values.begin(),
1074 bias_cst.value_begin<APFloat>(),
1075 bias_cst.value_end<APFloat>());
1076 } else {
1077 return failure();
1078 }
1079
1080 int64_t flatten_index = 0;
1081 for (auto fp_it = filter_cst.value_begin<APFloat>(),
1082 fp_end = filter_cst.value_end<APFloat>();
1083 fp_it != fp_end; ++fp_it) {
1084 int bias_index = (flatten_index++ / slice_size) % bias_size;
1085
1086 new_bias_values[bias_index] =
1087 new_bias_values[bias_index] + *fp_it * cst_value;
1088 }
1089 auto new_bias = DenseFPElementsAttr::get(new_bias_type, new_bias_values);
1090 auto new_bias_op =
1091 rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
1092 fc_op.setOperand(0, binary_op->getOperand(0));
1093 fc_op.setOperand(2, new_bias_op);
1094 } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
1095 // The fusion of mul/div is actually applying the following
1096 // transformation:
1097 // w * (x ' c) + b => (w ' c) x + b
1098 // so we have to update the weight.
1099 bool is_mul = llvm::isa<MulOp>(binary_op);
1100 auto new_filter =
1101 filter_cst.mapValues(filter_type.getElementType(), [&](APFloat it) {
1102 return (is_mul ? it * cst_value : it / cst_value).bitcastToAPInt();
1103 });
1104 // We recreate the constant op in case it is shared by the other ops. This
1105 // might increase the model size.
1106 auto new_filter_op = rewriter.create<ConstOp>(
1107 fc_op.getLoc(), filter.getType(), new_filter);
1108 fc_op.setOperand(0, binary_op->getOperand(0));
1109 if (fc_op.filter() != filter) {
        // This filter goes through quantize and dequantize ops, so we just
        // need to update the weight that feeds the quantize op.
1112 filter.replaceAllUsesWith(new_filter_op);
1113 } else {
        // This filter doesn't go through quantize and dequantize ops, so we
        // update the weight of the affine op directly.
1116 fc_op.setOperand(1, new_filter_op);
1117 }
1118 } else {
1119 return failure();
1120 }
1121 return success();
1122 }
1123
1124 private:
  // Returns the length of the channel dimension and the slice size for each
  // position in the channel dimension. tfl.conv2d and tfl.fully_connected have
  // a leading channel dimension, but tfl.depthwise_conv2d has a trailing
  // channel dimension. This function derives that information from the op's
  // properties.
  static std::pair<int64_t, int64_t> GetBiasDimAndSliceSize(
      ArrayRef<int64_t> filter_shape, AffineOpType op) {
1132 // Channel dimension index is specified as op property
1133 auto channel_index_iter = filter_shape.begin();
1134 std::advance(channel_index_iter, op.GetChannelDimIndex());
    // The slice size is the product of all dimensions after the channel
    // dimension.
1136 int64_t slice_size =
1137 std::accumulate(std::next(channel_index_iter), filter_shape.end(), 1,
1138 std::multiplies<int64_t>());
1139 return {*channel_index_iter, slice_size};
1140 }
1141 };
1142
1143 // If the operand to a broadcastable op is a splat constant, try to replace it
1144 // with a 0-d constant, e.g. before this optimization,
1145 // %cst = arith.constant dense<1.0> : tensor<16x16x4xf32>
1146 // %0 = "tfl.conv_2d"...
1147 // %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
1148 // After this optimization:
1149 // %cst = arith.constant dense<1.0> : tensor<f32>
1150 // %0 = "tfl.conv_2d"...
1151 // %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
1152 // This pattern can enable more fusing opportunities when the binary op is
1153 // following conv ops.
1154 template <typename BinaryOpType>
1155 struct ScalarizeSplatConstantForBroadcastableOps
1156 : public OpRewritePattern<BinaryOpType> {
1157 using OpRewritePattern<BinaryOpType>::OpRewritePattern;
1158
  LogicalResult matchAndRewrite(BinaryOpType binary_op,
                                PatternRewriter &rewriter) const override {
1161 DenseElementsAttr splat_elements_attr;
1162 if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
1163 return failure();
1164 }
1165
1166 constexpr int kSplatOperandIndex = 1;
1167 auto result_type =
1168 binary_op.getResult().getType().template cast<ShapedType>();
1169 mlir::Value non_splat_operand =
1170 binary_op.getOperand(1 - kSplatOperandIndex);
1171 auto non_splat_operand_type =
1172 non_splat_operand.getType().cast<ShapedType>();
    // If the other operand's shape is not equal to the result shape, then we
    // cannot scalarize the splat constant because the result shape relies on
    // the splat constant op's shape for broadcasting.
1176 if (!non_splat_operand_type.hasStaticShape() ||
1177 non_splat_operand_type.getShape() != result_type.getShape() ||
1178 non_splat_operand_type.getRank() > 4) {
1179 return failure();
1180 }
1181
    // If the non-splat operand is not produced by a fusable affine op, there
    // is no need to apply this transformation.
1184 if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
1185 return failure();
1186 }
1187
1188 // Creates a new scalar constant op using the splat value.
1189 mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
1190 auto scalar_elements_attr = DenseElementsAttr::get(
1191 RankedTensorType::get({},
1192 splat_elements_attr.getType().getElementType()),
1193 splat_elements_attr.getSplatValue<mlir::Attribute>());
1194
1195 auto scalar_constant_op = rewriter.create<arith::ConstantOp>(
1196 splat_operand.getLoc(), scalar_elements_attr.getType(),
1197 scalar_elements_attr);
1198
1199 binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
1200 return success();
1201 }
1202
1203 private:
1204 // Returns true if this value is a splat constant op which can be scalarized.
1205 // Also returns the elements attr if this value is indeed a splat constant.
  bool IsScalarizableSplatConstant(mlir::Value value,
                                   DenseElementsAttr *elements_attr) const {
1208 if (!matchPattern(value, m_Constant(elements_attr))) {
1209 return false;
1210 }
1211 auto element_type = value.getType().cast<ShapedType>().getElementType();
    // Ignore per-axis quantized constants because after converting to a
    // scalar, we would lose the per-axis quantization parameters.
1214 if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
1215 return false;
1216 }
1217 if (IsScalar(value)) {
1218 return false;
1219 }
1220 return elements_attr->isSplat();
1221 }
1222
  // Returns true if the value has a scalar shaped type (a static shape with
  // exactly one element).
  bool IsScalar(mlir::Value value) const {
1225 auto type = value.getType().dyn_cast<ShapedType>();
1226 if (!type) {
1227 return false;
1228 }
1229 if (!type.hasStaticShape()) {
1230 return false;
1231 }
1232 return type.getNumElements() == 1;
1233 }
1234
  // Returns true if we can fuse an affine op with the consuming binary op.
  bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
1237 if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
1238 TFL::FullyConnectedOp>(affine_op)) {
1239 return false;
1240 }
1241 DenseElementsAttr value;
    // Check that the bias is a constant if it is not none.
1243 Value bias = affine_op->getOperand(2);
1244 if (!bias.getType().isa<NoneType>() &&
1245 !matchPattern(bias, m_Constant(&value))) {
1246 return false;
1247 }
1248 // If the binary op is mul/div, also check that filter is constant.
1249 if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
1250 !matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
1251 return false;
1252 }
1253
1254 // We can only fuse F32/BF16.
1255 auto is_fusable_type = [](Type t) {
1256 Type element_type = t;
1257 if (auto shaped_type = t.dyn_cast<ShapedType>()) {
1258 element_type = shaped_type.getElementType();
1259 }
1260 return element_type.isBF16() || element_type.isF32();
1261 };
1262 for (Type t : binary_op->getOperandTypes()) {
1263 if (!is_fusable_type(t)) {
1264 return false;
1265 }
1266 }
1267
1268 return true;
1269 }
1270 };
1271
1272 using ScalarizeSplatConstantForSub =
1273 ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
1274 using ScalarizeSplatConstantForAdd =
1275 ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
1276 using ScalarizeSplatConstantForMul =
1277 ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
1278 using ScalarizeSplatConstantForDiv =
1279 ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
1280
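// Converts a Transpose that only moves degenerate (size-1) dimensions, i.e.
// one that preserves the relative order of all non-degenerate dimensions, into
// an equivalent Reshape.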
1281 struct ConvertTrivialTransposeOpToReshapeOp
1282 : public OpRewritePattern<TFL::TransposeOp> {
1283 using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
1284
  LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op,
                                PatternRewriter &rewriter) const override {
1287 auto input_type = transpose_op.input().getType().cast<ShapedType>();
1288 auto output_type = transpose_op.output().getType().cast<ShapedType>();
    // We can only tell whether the transformation is safe if the input and
    // output shapes are fully known and the permutation is a constant.
1291 if (!input_type.hasStaticShape() || !output_type.hasStaticShape())
1292 return failure();
1293 Value perm = transpose_op.perm();
1294 DenseElementsAttr perm_values_attr;
1295 if (!matchPattern(perm, m_Constant(&perm_values_attr))) return failure();
1296
1297 auto input_shape = input_type.getShape();
1298 SmallVector<int64_t, 8> perm_values;
1299 for (const auto &dim : perm_values_attr.getValues<APInt>())
1300 perm_values.push_back(dim.getSExtValue());
1301
1302 // This should never happen unless the input graph is malformed.
1303 if (input_shape.size() != perm_values.size()) {
      transpose_op.emitError(
          "TransposeOp has inconsistent input and perm values.");
1306 }
1307
1308 SmallVector<int, 8> old_major_index_ordering;
1309 SmallVector<int, 8> new_major_index_ordering;
1310 for (int i = 0, end = input_shape.size(); i < end; i++) {
1311 if (input_shape[i] != 1) {
1312 old_major_index_ordering.push_back(i);
1313 }
1314
1315 if (input_shape[perm_values[i]] != 1) {
1316 new_major_index_ordering.push_back(perm_values[i]);
1317 }
1318 }
1319 if (old_major_index_ordering != new_major_index_ordering) {
1320 return failure();
1321 }
1322
1323 // Rewrite.
1324 Location loc = transpose_op.getLoc();
1325
1326 SmallVector<int32_t, 8> output_shape_values;
1327 for (auto dim : output_type.getShape()) {
1328 output_shape_values.push_back(dim);
1329 }
1330 auto type = mlir::RankedTensorType::get(output_shape_values.size(),
1331 rewriter.getIntegerType(32));
1332 auto new_shape_attr =
1333 mlir::DenseIntElementsAttr::get(type, output_shape_values);
1334 auto new_shape = rewriter.create<TF::ConstOp>(loc, new_shape_attr);
1335
1336 rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(
1337 transpose_op, transpose_op.output().getType(), transpose_op.input(),
1338 new_shape);
1339
1340 return success();
1341 }
1342 };
1343
1344 // Remove Reshape before FullyConnected when `keep_num_dims=false` and Reshape
1345 // does not alter the last dimension as FullyConnected will collapse all other
1346 // dimensions into a single dimension. For example,
1347 //
1348 // %shape = arith.constant dense<[1, 128, 64]> : tensor<3xi32>
1349 // %reshape = tfl.reshape(%input, %shape) // %input: tensor<128x64xf32>
1350 // %fc = tfl.fully_connected(%reshape, %filter, %bias)
1351 // {keep_num_dims = false, weights_format = "DEFAULT"}
1352 //
1353 // can be canonicalized to
1354 //
1355 // %fc = tfl.fully_connected(%input, %filter, %bias)
1356 // {keep_num_dims = false, weights_format = "DEFAULT"}
1357 struct RemoveReshapeBeforeFullyConnected
1358 : public OpRewritePattern<TFL::FullyConnectedOp> {
1359 using OpRewritePattern<TFL::FullyConnectedOp>::OpRewritePattern;
1360
  LogicalResult matchAndRewrite(TFL::FullyConnectedOp fully_connected_op,
                                PatternRewriter &) const override {
    auto input = fully_connected_op.input();
    auto input_ty = input.getType().dyn_cast<ShapedType>();
    auto output_ty = fully_connected_op.output()[0]
                         .getType()
                         .template dyn_cast<ShapedType>();
    if (!input_ty.hasStaticShape() ||
        fully_connected_op.weights_format() != "DEFAULT" ||
        fully_connected_op.keep_num_dims() || !output_ty.hasStaticShape() ||
        output_ty.getRank() != 2) {
      return failure();
    }

    auto reshape_op = input.getDefiningOp<TFL::ReshapeOp>();
    if (!reshape_op) return failure();

    // Check if the last dimension does not change after reshape.
    auto reshape_input = reshape_op.input();
    auto reshape_input_ty = reshape_input.getType().dyn_cast<ShapedType>();
    if (!reshape_input_ty.hasStaticShape() || input_ty.getRank() == 0 ||
        reshape_input_ty.getRank() == 0 ||
        input_ty.getDimSize(input_ty.getRank() - 1) !=
            reshape_input_ty.getDimSize(reshape_input_ty.getRank() - 1)) {
      return failure();
    }

    // Bypass the reshape: feed its input directly into the FullyConnected op.
    fully_connected_op.setOperand(0, reshape_input);
    return success();
  }
};

// Remove Reshape after FullyConnected when `keep_num_dims=false`, the Reshape
// does not alter the last dimension and it restores the batch dimensions
// collapsed by the FullyConnected op due to `keep_num_dims=false`. For example,
//
// // %input: tensor<4x16x32xf32>
// %fc = tfl.fully_connected(%input, %filter, %bias)
//         {keep_num_dims = false, weights_format = "DEFAULT"}
// %shape = arith.constant dense<[4, 16, 32]> : tensor<3xi32>
// %rs = tfl.reshape(%fc, %shape)
//
// can be canonicalized to
//
// %fc = tfl.fully_connected(%input, %filter, %bias)
//         {keep_num_dims = true, weights_format = "DEFAULT"}
struct RemoveReshapeAfterFullyConnected
    : public OpRewritePattern<TFL::ReshapeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::ReshapeOp reshape_op,
                                PatternRewriter &rewriter) const override {
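    // Match only a reshape whose operand is produced by a single-result
    // FullyConnected op with the default weights format and
    // `keep_num_dims = false`, and which is that result's only user.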
    auto fully_connected_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
        reshape_op.input().getDefiningOp());
    if (!fully_connected_op || fully_connected_op.getNumResults() != 1 ||
        fully_connected_op.weights_format() != "DEFAULT" ||
        fully_connected_op.keep_num_dims())
      return failure();
    if (!reshape_op.input().hasOneUse()) return failure();

    auto input_shape = fully_connected_op.input().getType().cast<ShapedType>();
    auto output_shape = fully_connected_op.getType(0).cast<ShapedType>();
    auto reshape_shape = reshape_op.getType().cast<ShapedType>();
    if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() ||
        !reshape_shape.hasStaticShape())
      return failure();

    // Check that the reshape doesn't modify the last (feature) dimension and
    // that it restores the input (batch) dimensions, i.e. everything except
    // the feature dimension matches the FullyConnected op's input shape.
    if (output_shape.getShape().empty() || reshape_shape.getShape().empty() ||
        output_shape.getShape().back() != reshape_shape.getShape().back() ||
        input_shape.getShape().drop_back() !=
            reshape_shape.getShape().drop_back())
      return failure();

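    // Recreate the FullyConnected op with `keep_num_dims = true`; its result
    // then already carries the batch dimensions the reshape was restoring, so
    // the reshape itself becomes redundant and is replaced.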
    llvm::SmallVector<Type, 1> output_type{reshape_op.getType()};
    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
        reshape_op, output_type, /*input=*/fully_connected_op.input(),
        /*filter=*/fully_connected_op.filter(),
        /*bias=*/fully_connected_op.bias(),
        /*fused_activation_function=*/
        fully_connected_op.fused_activation_function(),
        /*weights_format=*/fully_connected_op.weights_format(),
        /*keep_num_dims=*/true,
        /*asymmetric_quantize_inputs=*/
        fully_connected_op.asymmetric_quantize_inputsAttr());
    return success();
  }
};

// Fuses Unpack with the following Concatenation into a Reshape if the output
// type has a static shape and the activation function is none. For example:
//
// // %input: tensor<1x3x2xf32>
// %unpack:3 = "tfl.unpack"(%input) {axis = 1 : i32, num = 3 : i32}
// %res = "tfl.concatenation"(%unpack#0, %unpack#1, %unpack#2)
//          {axis = -1 : i32, fused_activation_function = "NONE"}
//
// can be optimized to
//
// %cst = arith.constant dense<[1, 6]> : tensor<2xi32>
// %res = "tfl.reshape"(%input, %cst)
struct FuseUnpackAndConcatToReshape
    : public OpRewritePattern<TFL::ConcatenationOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op,
                                PatternRewriter &rewriter) const override {
    if (concat_op.fused_activation_function() != "NONE") {
      return failure();
    }

    // Check that all operands come from the same unpack op and appear in the
    // same order as its results.
    auto first_operand = concat_op.values().front();
    auto unpack_op =
        dyn_cast_or_null<TFL::UnpackOp>(first_operand.getDefiningOp());
    if (!unpack_op || unpack_op.getNumResults() != concat_op.getNumOperands()) {
      return failure();
    }
    for (auto &index_and_value : llvm::enumerate(concat_op.values())) {
      if (index_and_value.value() !=
          unpack_op.getResult(index_and_value.index())) {
        return failure();
      }
    }

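    // The reshape's target shape is materialized as a constant below, so the
    // concatenation result must have a fully static shape.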
    auto output_type = concat_op.getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) {
      return failure();
    }

    auto new_shape_array = output_type.getShape();
    // This is to work around the unnecessary cast i64 -> i32.
    SmallVector<int32_t, 4> new_shape_array_i32;
    for (auto size : new_shape_array) {
      new_shape_array_i32.push_back(static_cast<int32_t>(size));
    }
    auto new_shape = rewriter.create<TFL::ConstOp>(
        concat_op.getLoc(),
        DenseIntElementsAttr::get(
            RankedTensorType::get(new_shape_array_i32.size(),
                                  rewriter.getIntegerType(32)),
            new_shape_array_i32));

    rewriter.replaceOpWithNewOp<TFL::ReshapeOp>(concat_op, output_type,
                                                unpack_op.input(), new_shape);
    return success();
  }
};

// Reduce the K of a TopKV2Op for the following case.
//
// values, indices = tfl.topkv2(%inputs, K)
// %1 = tfl.slice(values, 0, k)
// %2 = tfl.slice(indices, 0, k)
// .... (values and indices only used for %1 and %2)
//
// %1 or %2 can be absent. If values and indices are only used here,
// this pattern can be replaced with (conceptually)
//
// %values, %indices = tfl.topkv2(%inputs, k)
// replace all use of %1 with values
// replace all use of %2 with indices
//
struct OptimizeTopK : public OpRewritePattern<TFL::TopKV2Op> {
  using OpRewritePattern::OpRewritePattern;

  // Computes the last-dimension slice size k used by value's user.
  // Returns 0 if value has no uses.
  llvm::Optional<int32_t> ComputeSliceK(Value value) const {
    if (value.use_empty()) return 0;
    auto slice_op =
        llvm::dyn_cast_or_null<TFL::SliceOp>(value.getUses().begin().getUser());
    // We only match for the case where value is used by SliceOp.
    if (!slice_op) return llvm::None;
    DenseElementsAttr begin;
    DenseElementsAttr size;
    if (!matchPattern(slice_op->getOperand(1), m_Constant(&begin)) ||
        !matchPattern(slice_op->getOperand(2), m_Constant(&size)))
      return llvm::None;

    // Check if "begin" is a zero tensor.
    for (auto begin_idx : begin.getValues<APInt>())
      if (begin_idx != 0) return llvm::None;

    // Check if "size" is equal to slice_op.input.shape except
    // for last dimension.
    // It can be done by verifying the number of elements:
    // i.e., num_input/input_last_dim = num_result/k
    auto input_ty = value.getType().dyn_cast_or_null<ShapedType>();
    auto result_ty = slice_op.getType().dyn_cast<ShapedType>();
    if (!input_ty || !result_ty) return llvm::None;
    if (!input_ty.hasStaticShape() || !result_ty.hasStaticShape())
      return llvm::None;
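    // A rank-0 value has no last dimension to slice over, so bail out.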
    if (!input_ty.getRank() || !result_ty.getRank()) return llvm::None;
    int num_input = input_ty.getNumElements();
    int input_last_dim = input_ty.getShape().back();
    if (input_last_dim < 1) return llvm::None;
    int num_result = result_ty.getNumElements();
    auto size_last = *(--size.value_end<APInt>());
    int32_t k = size_last.getSExtValue();
    if (num_input / input_last_dim * k != num_result) return llvm::None;
    // We don't match a SliceOp whose last dim size is 0.
    if (!k) return llvm::None;
    return k;
  }

  LogicalResult matchAndRewrite(TFL::TopKV2Op op,
                                PatternRewriter &rewriter) const override {
    auto values = op.values();
    auto indices = op.indices();
    // op.values() and op.indices() cannot be used more than once.
    if (!values.hasOneUse() && !values.use_empty()) return failure();
    if (!indices.hasOneUse() && !indices.use_empty()) return failure();

    auto k_values_or = ComputeSliceK(values);
    auto k_indices_or = ComputeSliceK(indices);
    if (!k_values_or.has_value() || !k_indices_or.has_value()) return failure();
    int32_t k_values = k_values_or.getValue();
    int32_t k_indices = k_indices_or.getValue();
    // We don't match two SliceOps with different sizes.
    if (k_values != k_indices && !values.use_empty() && !indices.use_empty())
      return failure();

    // Start replacing.
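    // ComputeSliceK returns 0 for an unused output, so take k from whichever
    // output actually feeds a SliceOp.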
    auto k = !values.use_empty() ? k_values : k_indices;
    // Build scalar tensor k.
    auto k_ty = mlir::RankedTensorType::get({}, rewriter.getIntegerType(32));
    Value k_cst = rewriter.create<TFL::ConstOp>(
        op.getLoc(), DenseElementsAttr::get(k_ty, k));
    // Compute new result types.
    auto values_ty = values.getType().dyn_cast<ShapedType>();
    auto indices_ty = indices.getType().dyn_cast<ShapedType>();
    auto shape = std::vector<int64_t>();
    for (auto d : values_ty.getShape().drop_back()) {
      shape.push_back(d);
    }
    shape.push_back(static_cast<int64_t>(k));
    auto new_values_ty =
        mlir::RankedTensorType::get(shape, values_ty.getElementType());
    auto new_indices_ty =
        mlir::RankedTensorType::get(shape, indices_ty.getElementType());
    TFL::TopKV2Op top_k_op = rewriter.create<TFL::TopKV2Op>(
        op.getLoc(), new_values_ty, new_indices_ty, op->getOperand(0), k_cst);

    // Remove original ops (topk, Slice, Slice).
    if (!values.use_empty()) {
      auto values_slice_op = llvm::dyn_cast_or_null<TFL::SliceOp>(
          values.getUses().begin().getUser());
      values_slice_op.getResult().replaceAllUsesWith(top_k_op.values());
      values_slice_op.erase();
    }
    if (!indices.use_empty()) {
      auto indices_slice_op = llvm::dyn_cast_or_null<TFL::SliceOp>(
          indices.getUses().begin().getUser());
      indices_slice_op.getResult().replaceAllUsesWith(top_k_op.indices());
      indices_slice_op.erase();
    }
    op.erase();
    return success();
  }
};

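// Instantiate the generic binary-op-into-following-affine-op fusion for
// fully connected, depthwise convolution, and convolution ops.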
using FuseBinaryOpToFollowingFullyConnected =
    FuseBinaryOpToFollowingAffineOp<FullyConnectedOp>;
using FuseBinaryOpToFollowingDepthwiseConv2D =
    FuseBinaryOpToFollowingAffineOp<DepthwiseConv2DOp>;
using FuseBinaryOpToFollowingConv2D = FuseBinaryOpToFollowingAffineOp<Conv2DOp>;

// Adds canonicalization patterns to the list of patterns.
void AddCanonicalizationPatterns(MLIRContext *context,
                                 RewritePatternSet *patterns) {
  for (auto op : context->getRegisteredOperations())
    op.getCanonicalizationPatterns(*patterns, context);
}

void OptimizePass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  auto *ctx = &getContext();
  auto func = getOperation();

  // Merge reshapes into fully connected ops before we start moving them past
  // binary ops.
  RewritePatternSet phase_0_patterns(&getContext());
  phase_0_patterns
      .add<RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected>(
          ctx);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_0_patterns));

  // Some binary ops may themselves be fused into larger ops (e.g. hard_swish),
  // so we explore those fusions first, and then fuse the remaining binary ops
  // with the ops that follow them in a second pattern match.
  TFL::populateWithGenerated(patterns);
  patterns.add<FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
               FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
               FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
               FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
               FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>>(ctx);
  if (this->enable_canonicalization_)
    AddCanonicalizationPatterns(ctx, &patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));

  // Fuse the binary ops with the following ops.
  RewritePatternSet phase_2_patterns(&getContext());
  TFL::populateWithGenerated(phase_2_patterns);
  phase_2_patterns.add<
      ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
      ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
      FuseFullyConnectedAndAdd, FuseAddAndFullyConnected,
      FuseFullyConnectedAndMul, FuseMulAndFullyConnected,
      FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
      FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
      FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,
      FuseBinaryOpToFollowingConv2D, FuseBinaryOpToFollowingDepthwiseConv2D,
      FuseBinaryOpToFollowingFullyConnected, FuseConv2DAndMulWithQDQs,
      FuseDepthwiseConv2DAndMulWithQDQs, ConvertTrivialTransposeOpToReshapeOp,
      RemoveReshapeAfterFullyConnected, RemoveReshapeBeforeFullyConnected,
      FuseUnpackAndConcatToReshape, OptimizeTopK>(ctx);
  if (this->enable_canonicalization_)
    AddCanonicalizationPatterns(ctx, &phase_2_patterns);
  (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns));
}
}  // namespace

// Creates an instance of the TensorFlow Lite dialect Optimize pass.
std::unique_ptr<OperationPass<func::FuncOp>> CreateOptimizePass(
    bool enable_canonicalization) {
  return std::make_unique<OptimizePass>(enable_canonicalization);
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateOptimizePass() {
  return std::make_unique<OptimizePass>();
}

}  // namespace TFL
}  // namespace mlir