/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Legalize TensorFlow to TOSA

#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>

24 #include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
25 #include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
26 #include "mlir/Support/LLVM.h"  // from @llvm-project
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
31 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
32 #include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
33 
34 #define PASS_NAME "tosa-legalize-tf"
35 #define DEBUG_TYPE PASS_NAME
36 
37 namespace mlir {
38 namespace tosa {
39 namespace {
40 
41 // Performs lowering to TOSA dialect
42 class LegalizeTF : public TosaLegalizeTFPassBase<LegalizeTF> {
43  public:
LegalizeTF()44   explicit LegalizeTF() {}
45   void runOnOperation() override;
46 };
47 
48 // All the Pat<> lowering mappings.
49 #include "tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.inc"
50 
51 #define DECL_CONVERT_OP(tf_op)                                               \
52   struct ConvertTF##tf_op##Op : public RewritePattern {                      \
53     explicit ConvertTF##tf_op##Op(MLIRContext* context)                      \
54         : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}   \
55     LogicalResult matchAndRewrite(Operation* op,                             \
56                                   PatternRewriter& rewriter) const override; \
57   }
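
// For example, DECL_CONVERT_OP(Relu) expands (roughly) to:
//
//   struct ConvertTFReluOp : public RewritePattern {
//     explicit ConvertTFReluOp(MLIRContext* context)
//         : RewritePattern(TF::ReluOp::getOperationName(), 1, context) {}
//     LogicalResult matchAndRewrite(Operation* op,
//                                   PatternRewriter& rewriter) const override;
//   };
//
// with matchAndRewrite() implemented out of line below.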

// All the explicitly implemented complex lowerings.
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(AddV2);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(RealDiv);
DECL_CONVERT_OP(ArgMax);
DECL_CONVERT_OP(AvgPool);
DECL_CONVERT_OP(MaxPool);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(DepthwiseConv2dNative);
DECL_CONVERT_OP(Conv2DBackpropInput);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(All);
DECL_CONVERT_OP(Any);
DECL_CONVERT_OP(Max);
DECL_CONVERT_OP(Min);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(Prod);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(FusedBatchNorm);
DECL_CONVERT_OP(FusedBatchNormV3);
DECL_CONVERT_OP(BiasAdd);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Gather);
DECL_CONVERT_OP(GatherV2);
DECL_CONVERT_OP(GatherNd);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(SpaceToBatchND);
DECL_CONVERT_OP(BatchToSpaceND);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Sigmoid);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(StopGradient);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(FakeQuantWithMinMaxArgs);
DECL_CONVERT_OP(FakeQuantWithMinMaxVars);
DECL_CONVERT_OP(LeftShift);
DECL_CONVERT_OP(RightShift);
DECL_CONVERT_OP(OneHot);
DECL_CONVERT_OP(BatchMatMulV2);
#undef DECL_CONVERT_OP

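// Relu is lowered to a single tosa.clamp that only bounds the low end of the
// range; the upper bounds are the numeric limits of the element type.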
LogicalResult ConvertTFReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_relu_op = cast<TF::ReluOp>(op);

  TensorType output_type =
      tf_relu_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::ClampOp>(
      rewriter, op, output_type, tf_relu_op.features(),
      rewriter.getI64IntegerAttr(0),
      rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
      rewriter.getF32FloatAttr(0.0f),
      rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
  return success();
}

LogicalResult ConvertTFRelu6Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_relu6_op = cast<TF::Relu6Op>(op);

  TensorType output_type =
      tf_relu6_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::ClampOp>(
      rewriter, op, output_type, tf_relu6_op.features(),
      rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(6),
      rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f));
  return success();
}

LogicalResult ConvertTFEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_equal_op = cast<TF::EqualOp>(op);

  TensorType output_type =
      tf_equal_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::EqualOp>(rewriter, op, output_type,
                                         tf_equal_op.x(), tf_equal_op.y());
  return success();
}

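// TOSA has no dedicated not_equal operator, so NotEqual is lowered to
// tosa.equal followed by tosa.logical_not.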
LogicalResult ConvertTFNotEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_not_equal_op = cast<TF::NotEqualOp>(op);

  TensorType output_type =
      tf_not_equal_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  auto op1_equal_in =
      CreateOpAndInfer<tosa::EqualOp>(rewriter, op->getLoc(), output_type,
                                      tf_not_equal_op.x(), tf_not_equal_op.y());

  auto op2_not_op1 = CreateOpAndInfer<tosa::LogicalNotOp>(
      rewriter, op->getLoc(), output_type, op1_equal_in.getResult());

  rewriter.replaceOp(op, {op2_not_op1.getResult()});

  return success();
}

LogicalResult ConvertTFGreaterOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_greater_op = cast<TF::GreaterOp>(op);

  TensorType output_type =
      tf_greater_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::GreaterOp>(
      rewriter, op, output_type, tf_greater_op.x(), tf_greater_op.y());
  return success();
}

LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_greater_equal_op = cast<TF::GreaterEqualOp>(op);

  TensorType output_type =
      tf_greater_equal_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(rewriter, op, output_type,
                                                tf_greater_equal_op.x(),
                                                tf_greater_equal_op.y());
  return success();
}

LogicalResult ConvertTFAddOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_add_op = cast<TF::AddOp>(op);

  TensorType output_type =
      tf_add_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::AddOp>(rewriter, op, output_type, tf_add_op.x(),
                                       tf_add_op.y());
  return success();
}

LogicalResult ConvertTFAddV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_addv2_op = cast<TF::AddV2Op>(op);

  TensorType output_type =
      tf_addv2_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::AddOp>(rewriter, op, output_type,
                                       tf_addv2_op.x(), tf_addv2_op.y());
  return success();
}

// AddN is commutative
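// It is lowered to a left-to-right chain of binary tosa.add ops,
// e.g. AddN(a, b, c) => add(c, add(a, b)).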
LogicalResult ConvertTFAddNOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_addn_op = cast<TF::AddNOp>(op);

  TensorType output_type =
      tf_addn_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  SmallVector<Value> inputs(tf_addn_op.inputs());

  assert(inputs.size() >= 2);

  auto newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(),
                                             output_type, inputs[0], inputs[1]);
  for (int i = 2; i < inputs.size(); i++) {
    newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
                                          inputs[i], newOp.getResult());
  }

  rewriter.replaceOp(op, {newOp.getResult()});

  return success();
}

LogicalResult ConvertTFSubOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_sub_op = cast<TF::SubOp>(op);

  TensorType output_type =
      tf_sub_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::SubOp>(rewriter, op, output_type, tf_sub_op.x(),
                                       tf_sub_op.y());
  return success();
}

LogicalResult ConvertTFMulOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_mul_op = cast<TF::MulOp>(op);

  llvm::Optional<Value> result = convertMultiplyOp(
      rewriter, op, tf_mul_op.getResult(), tf_mul_op.x(), tf_mul_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFSquareOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_square_op = cast<TF::SquareOp>(op);

  llvm::Optional<Value> result =
      convertMultiplyOp(rewriter, op, tf_square_op.getResult(),
                        tf_square_op.x(), tf_square_op.x());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFSquaredDifferenceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_squared_op = cast<TF::SquaredDifferenceOp>(op);

  llvm::Optional<Value> result =
      convertSquaredDifferenceOp(rewriter, op, tf_squared_op.getResult(),
                                 tf_squared_op.x(), tf_squared_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

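// Round only needs real work for floating-point inputs (delegated to
// convertRoundOp); integer inputs are already rounded and simply pass
// through.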
LogicalResult ConvertTFRoundOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_round_op = cast<TF::RoundOp>(op);

  TensorType input_type = tf_round_op.x().getType().dyn_cast<TensorType>();
  if (!input_type) {
    return op->emitOpError("Round: input not tensor type");
  }

  if (input_type.getElementType().isa<FloatType>()) {
    llvm::Optional<Value> result =
        convertRoundOp(rewriter, op, tf_round_op.getResult(), tf_round_op.x());

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});
    return success();

  } else {
    // Replace through the rewriter (rather than replaceAllUsesWith) so the
    // pattern driver is notified of the change.
    rewriter.replaceOp(op, {tf_round_op.x()});
    return success();
  }
}

LogicalResult ConvertTFFloorDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_floordiv_op = cast<TF::FloorDivOp>(op);

  llvm::Optional<Value> result =
      convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(),
                        tf_floordiv_op.x(), tf_floordiv_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFFloorModOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_floormod_op = cast<TF::FloorModOp>(op);

  llvm::Optional<Value> result =
      convertFloorModOp(rewriter, op, tf_floormod_op.getResult(),
                        tf_floormod_op.x(), tf_floormod_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFAssertOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  // Assert has no TOSA equivalent; erase it through the rewriter rather than
  // calling op->erase() directly, so the pattern driver is notified.
  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertTFMaximumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_maximum_op = cast<TF::MaximumOp>(op);

  TensorType output_type =
      tf_maximum_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::MaximumOp>(
      rewriter, op, output_type, tf_maximum_op.x(), tf_maximum_op.y());
  return success();
}

LogicalResult ConvertTFMinimumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_minimum_op = cast<TF::MinimumOp>(op);

  TensorType output_type =
      tf_minimum_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::MinimumOp>(
      rewriter, op, output_type, tf_minimum_op.x(), tf_minimum_op.y());
  return success();
}

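// RealDiv maps directly to tosa.div for integer element types; for float
// types it is decomposed into x * reciprocal(y).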
LogicalResult ConvertTFRealDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_div_op = cast<TF::RealDivOp>(op);

  TensorType y_type = tf_div_op.y().getType().dyn_cast<TensorType>();
  TensorType output_type =
      tf_div_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type || !y_type) return failure();

  Type element_type = output_type.getElementType();

  if (element_type.isa<IntegerType>()) {
    CreateReplaceOpAndInfer<tosa::DivOp>(rewriter, op, output_type,
                                         tf_div_op.x(), tf_div_op.y());
    return success();
  }

  auto reciprocal_op = CreateOpAndInfer<tosa::ReciprocalOp>(
      rewriter, op->getLoc(), tf_div_op.y().getType(), tf_div_op.y());

  auto mul_op = CreateOpAndInfer<tosa::MulOp>(rewriter, op->getLoc(),
                                              output_type, tf_div_op.x(),
                                              reciprocal_op.getResult(), 0);
  rewriter.replaceOp(op, {mul_op.getResult()});

  return success();
}

LogicalResult ConvertTFArgMaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_argmax_op = cast<TF::ArgMaxOp>(op);

  TensorType input_type = tf_argmax_op.input().getType().dyn_cast<TensorType>();
  TensorType output_type =
      tf_argmax_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type || !input_type) return failure();

  ElementsAttr axis_elems;
  if (!matchPattern(tf_argmax_op.dimension(), m_Constant(&axis_elems)))
    return failure();

  int32_t axis = axis_elems.getValues<IntegerAttr>()[0].getInt();
  if (axis < 0) {
    axis += input_type.getRank();
  }

  if (axis < 0 || axis >= input_type.getRank()) {
    return op->emitOpError("TFArgMax: invalid axis value");
  }

  IntegerAttr axis_attr = rewriter.getI64IntegerAttr(axis);

  CreateReplaceOpAndInfer<tosa::ArgMaxOp>(rewriter, op, output_type,
                                          tf_argmax_op.input(), axis_attr);

  return success();
}
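
// AvgPool is lowered to tosa.avg_pool2d: the TF ksize, strides, and padding
// attributes (NHWC only) are translated into TOSA kernel/stride/pad array
// attributes below.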
LogicalResult ConvertTFAvgPoolOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_avgpool_op = cast<TF::AvgPoolOp>(op);

  RankedTensorType input_type =
      tf_avgpool_op.value().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  auto tmpAttr = tf_avgpool_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr kernel;
  {
    auto tmpAttr = tf_avgpool_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_avgpool_op.ksize();
    if (!tmpAttr) {
      kernel = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_avgpool_op.padding().str(), &tf_pad).ok())
      return failure();

    ArrayAttr dilation =
        rewriter.getI64ArrayAttr({1, 1});  // Pooling has no non-unit dilation

    SmallVector<int64_t, 4> i64array;

    for (auto& elem : tf_avgpool_op.ksize()) {
      int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
      i64array.emplace_back(value);
    }

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  CreateReplaceOpAndInfer<tosa::AvgPool2dOp>(
      rewriter, op, output_type, tf_avgpool_op.value(), kernel, stride, pad);
  return success();
}

LogicalResult ConvertTFMaxPoolOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_maxpool_op = cast<TF::MaxPoolOp>(op);

  RankedTensorType input_type =
      tf_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  auto tmpAttr = tf_maxpool_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr kernel;
  {
    auto tmpAttr = tf_maxpool_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_maxpool_op.ksize();
    if (!tmpAttr) {
      kernel = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_maxpool_op.padding().str(), &tf_pad).ok())
      return failure();

    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});

    SmallVector<int64_t, 4> i64array;

    for (auto& elem : tf_maxpool_op.ksize()) {
      int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
      i64array.emplace_back(value);
    }

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  CreateReplaceOpAndInfer<tosa::MaxPool2dOp>(
      rewriter, op, output_type, tf_maxpool_op.input(), kernel, stride, pad);
  return success();
}

LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
  SmallVector<Value> values(tf_concatv2_op.values());

  ElementsAttr axis_elems;
  if (!matchPattern(tf_concatv2_op.axis(), m_Constant(&axis_elems)))
    return failure();

  int32_t axis = axis_elems.getValues<IntegerAttr>()[0].getInt();

  llvm::Optional<Value> result =
      convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFReshapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_reshape_op = cast<TF::ReshapeOp>(op);

  RankedTensorType output_type =
      tf_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Matching the shape operand as a constant attribute doesn't always work;
  // use output_type.getShape(), which is more stable.
  SmallVector<int64_t> shape_vals;
  for (int i = 0; i < output_type.getShape().size(); i++) {
    shape_vals.push_back(output_type.getShape()[i]);
  }
  ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);

  CreateReplaceOpAndInfer<tosa::ReshapeOp>(rewriter, op, output_type,
                                           tf_reshape_op.tensor(), shape_attr);
  return success();
}

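// The rank of a ranked tensor is known at compile time, so Rank is
// materialized as a single-element i32 tosa.const.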
LogicalResult ConvertTFRankOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_rank_op = cast<TF::RankOp>(op);

  RankedTensorType input_type =
      tf_rank_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  int32_t rank = input_type.getRank();

  RankedTensorType rank_type =
      RankedTensorType::get({1}, rewriter.getIntegerType(32));
  auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
  auto rank_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                    rank_type, rank_attr);

  rewriter.replaceOp(op, {rank_const.getResult()});

  return success();
}

LogicalResult ConvertTFShapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_shape_op = cast<TF::ShapeOp>(op);

  RankedTensorType output_type =
      tf_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tf_shape_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  auto input_shape = input_type.getShape();

  SmallVector<int32_t> shape_arr;
  for (int i = 0; i < input_shape.size(); i++) {
    shape_arr.emplace_back(input_shape[i]);
  }

  RankedTensorType shape_type = RankedTensorType::get(
      {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
  auto shape_attr =
      DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr));
  auto shape_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                     shape_type, shape_attr);

  rewriter.replaceOp(op, {shape_const.getResult()});

  return success();
}

LogicalResult ConvertTFExpandDimsOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_expanddims_op = cast<TF::ExpandDimsOp>(op);

  llvm::Optional<Value> result =
      convertExpandDimsOp(rewriter, op, tf_expanddims_op.getResult(),
                          tf_expanddims_op.input(), tf_expanddims_op.dim());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSqueezeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_squeeze_op = cast<TF::SqueezeOp>(op);

  // Copy squeeze_dims into int32_t array
  auto squeeze_dims_attr = tf_squeeze_op.squeeze_dimsAttr();
  SmallVector<int32_t> squeeze_dims;
  for (auto& squeeze_dim : squeeze_dims_attr) {
    squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
  }

  llvm::Optional<Value> result =
      convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(),
                       tf_squeeze_op.input(), squeeze_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

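// Fill requires both dims and value to be compile-time constants; the result
// is then materialized directly as a splatted tosa.const of the output shape.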
LogicalResult ConvertTFFillOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_fill_op = cast<TF::FillOp>(op);

  RankedTensorType output_type =
      tf_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr dims_elems;
  if (!matchPattern(tf_fill_op.dims(), m_Constant(&dims_elems)))
    return failure();
  SmallVector<int64_t> dims_vals;
  uint32_t total_size = 1;
  for (int i = 0; i < dims_elems.getNumElements(); i++) {
    dims_vals.push_back(dims_elems.getValues<IntegerAttr>()[i].getInt());
    total_size *= dims_vals[i];
  }

  ElementsAttr value_elem;
  if (!matchPattern(tf_fill_op.value(), m_Constant(&value_elem)))
    return failure();

  RankedTensorType fill_type = RankedTensorType::get(
      ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
  DenseElementsAttr fill_attr;

  // Splat the constant fill value into a dense attribute of a compatible type
  if (value_elem.getType().getElementType().isa<FloatType>()) {
    SmallVector<float> fill_arr(
        total_size,
        value_elem.getValues<FloatAttr>()[0].getValue().convertToFloat());
    fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
  } else {
    SmallVector<int32_t> fill_arr(
        total_size,
        value_elem.getValues<IntegerAttr>()[0].getValue().getLimitedValue());
    fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
  }
  auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                       fill_type, fill_attr);
  rewriter.replaceOp(op, {fill_const_op.getResult()});

  return success();
}

LogicalResult ConvertTFConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_conv2d_op = cast<TF::Conv2DOp>(op);

  RankedTensorType filter_type =
      tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!filter_type || !output_type) return failure();

  // Set up a zero attr for subsequent pattern replacement if required
  auto bias_dim = filter_type.getShape().back();
  RankedTensorType bias_type =
      RankedTensorType::get({bias_dim}, filter_type.getElementType());
  auto bias_attr = rewriter.getZeroAttr(bias_type);
  auto bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
                                              bias_attr.cast<ElementsAttr>());

  llvm::Optional<Value> result = convertTFConv2DCommon(
      rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
      bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
      tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
      tf_conv2d_op.data_format());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_dwconv2d_op = cast<TF::DepthwiseConv2dNativeOp>(op);

  RankedTensorType input_type =
      tf_dwconv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tf_dwconv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_dwconv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();

  // The filter must be ranked so the bias shape can be computed below
  if (!filter_type) {
    return op->emitOpError("DepthwiseConv2d: filter type unranked tensor");
  }

  auto tmpAttr = tf_dwconv2d_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr pad;
  {
    auto tmpAttr = tf_dwconv2d_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_dwconv2d_op.dilations();
    if (!tmpAttr) {
      dilation = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_dwconv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    tensorflow::TensorFormat data_format_tf;
    if (!FormatFromString(tf_dwconv2d_op.data_format().str(), &data_format_tf))
      return failure();

    if (tf_pad == tensorflow::Padding::EXPLICIT) {
      pad = getPaddingValuesFromExplicitPadAttr(
          tf_dwconv2d_op.explicit_paddings(), data_format_tf, rewriter);
    } else {
      if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
                                       0,  // tensorflow::FORMAT_HWIO
                                       input_type, filter_type, stride,
                                       dilation, rewriter, pad))
        return failure();
    }
  }

  auto filter_shape = filter_type.getShape();
  auto bias_dim = filter_shape[2] * filter_shape[3];
  RankedTensorType bias_type =
      RankedTensorType::get({bias_dim}, filter_type.getElementType());
  auto bias_attr = rewriter.getZeroAttr(bias_type);
  auto bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
                                              bias_attr.cast<ElementsAttr>());

  CreateReplaceOpAndInfer<tosa::DepthwiseConv2DOp>(
      rewriter, op, output_type, tf_dwconv2d_op.input(),
      tf_dwconv2d_op.filter(), bias, pad, stride, dilation);
  return success();
}

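// Conv2DBackpropInput (the gradient of a convolution w.r.t. its input, i.e. a
// transposed convolution) lowers to tosa.transpose_conv2d: the filter is
// transposed from TF's [H, W, I, O] layout to TOSA's [O, H, W, I], and a
// zero bias is synthesized.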
LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_conv_op = cast<TF::Conv2DBackpropInputOp>(op);

  RankedTensorType input_type =
      tf_conv_op.out_backprop().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tf_conv_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!filter_type) return failure();
  if (!output_type) return failure();

  // Transpose [H, W, I, O] to [O, H, W, I]
  auto filter_shape = filter_type.getShape();
  SmallVector<int64_t, 4> a1_transpose_dims;
  a1_transpose_dims.push_back(filter_shape[2]);
  a1_transpose_dims.push_back(filter_shape[0]);
  a1_transpose_dims.push_back(filter_shape[1]);
  a1_transpose_dims.push_back(filter_shape[3]);
  llvm::Optional<Value> a1_filter_transpose_perm = getConstTensor<int32_t>(
      rewriter, op, /*vec=*/{2, 0, 1, 3}, /*shape=*/{4});

  if (!a1_filter_transpose_perm) return failure();

  auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
      rewriter, op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
                            filter_type.getElementType()),
      tf_conv_op.filter(), a1_filter_transpose_perm.getValue());

  ArrayAttr stride;
  ArrayAttr outpad;
  ArrayAttr output_shape;
  {
    auto tmpAttr = tf_conv_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_conv_op.dilations();
    if (tmpAttr) {
      // Note: hardcoded to NHWC for now
      int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      // TOSA transpose_conv2d does not support non-unit dilation
      if (dilation_h != 1 || dilation_w != 1) return failure();
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_conv_op.padding().str(), &tf_pad).ok())
      return failure();

    tensorflow::TensorFormat data_format_tf;
    if (!FormatFromString(tf_conv_op.data_format().str(), &data_format_tf))
      return failure();

    if (tf_pad == tensorflow::Padding::EXPLICIT) {
      outpad = getPaddingValuesFromExplicitPadAttr(
          tf_conv_op.explicit_paddings(), data_format_tf, rewriter);
    } else {
      if (!getTransposeConv2dPaddingValues(tf_pad, data_format_tf,
                                           0,  // tensorflow::FORMAT_HWIO,
                                           input_type, filter_type, output_type,
                                           stride, rewriter, outpad))
        return failure();
    }
  }
  {
    ElementsAttr output_shape_elems;
    // Match from input_sizes tensor first.
    if (matchPattern(tf_conv_op.input_sizes(),
                     m_Constant(&output_shape_elems))) {
      SmallVector<int64_t> shape_vec;
      for (int i = 0; i < output_shape_elems.getNumElements(); i++)
        shape_vec.push_back(
            output_shape_elems.getValues<IntegerAttr>()[i].getInt());
      output_shape = rewriter.getI64ArrayAttr(shape_vec);
    } else {
      // Use output tensor's shape otherwise.
      output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
    }
  }

  int output_channel = output_type.getShape()[3];
  SmallVector<float> vec(output_channel, 0.0f);
  llvm::Optional<Value> zero_bias =
      getConstTensor<float>(rewriter, op, vec, {output_channel});

  if (!zero_bias) return failure();

  CreateReplaceOpAndInfer<tosa::TransposeConv2DOp>(
      rewriter, op, output_type, tf_conv_op.out_backprop(),
      a1_filter_transpose_op.getResult(), zero_bias.getValue(), outpad, stride,
      output_shape);

  return success();
}

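// The reduction lowerings below all follow the same pattern: the
// reduction_indices operand must be a compile-time constant, and the actual
// decomposition is delegated to the convertReduce*Op helpers declared in
// legalize_common.h.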
LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_all_op = cast<TF::AllOp>(op);

  RankedTensorType output_type =
      tf_all_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_all_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceAllOp(
      rewriter, op, output_type, tf_all_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_any_op = cast<TF::AnyOp>(op);

  RankedTensorType output_type =
      tf_any_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_any_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceAnyOp(
      rewriter, op, output_type, tf_any_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_max_op = cast<TF::MaxOp>(op);

  RankedTensorType output_type =
      tf_max_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_max_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceMaxOp(
      rewriter, op, output_type, tf_max_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_min_op = cast<TF::MinOp>(op);

  RankedTensorType output_type =
      tf_min_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_min_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceMinOp(
      rewriter, op, output_type, tf_min_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFMeanOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_mean_op = cast<TF::MeanOp>(op);

  RankedTensorType output_type =
      tf_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_mean_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceMeanOp(
      rewriter, op, output_type, tf_mean_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFProdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_prod_op = cast<TF::ProdOp>(op);

  RankedTensorType output_type =
      tf_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_prod_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceProdOp(
      rewriter, op, output_type, tf_prod_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_sum_op = cast<TF::SumOp>(op);

  RankedTensorType output_type =
      tf_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_sum_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceSumOp(
      rewriter, op, output_type, tf_sum_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFEluOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_elu_op = cast<TF::EluOp>(op);

  llvm::Optional<Value> result =
      convertEluOp(rewriter, op, tf_elu_op.getResult(), tf_elu_op.features());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_softmax_op = cast<TF::SoftmaxOp>(op);

  llvm::Optional<Value> result =
      convertSoftmaxOp(rewriter, op, tf_softmax_op.getResult(),
                       tf_softmax_op.logits(), /*beta=*/1.0);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLogSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_logsoftmax_op = cast<TF::LogSoftmaxOp>(op);

  llvm::Optional<Value> result = convertLogSoftmaxOp(
      rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.logits());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_batchnorm_op = cast<TF::FusedBatchNormOp>(op);

  RankedTensorType output_type =
      tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Lowering:
  // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
  //
  // shape_0 = ones(input.rank)
  // shape_0[input.rank-1] = input.shape[input.rank-1]
  // shape_1 = ones(1)
  //
  // bmean  = reshape(mean, shape_0)
  // bscale = reshape(scale, shape_0)
  // boffset= reshape(offset, shape_0)
  // beps   = reshape(epsilon, shape_1)
  //
  // op1 = sub(input, bmean)
  // op2 = add(var, beps)
  // op3 = rsqrt(op2)
  // bvar = reshape(op3, shape_0)
  // op4 = mul(op1, bvar)
  // op5 = mul(op4, bscale)
  // op6 = add(op5, boffset)

  RankedTensorType mean_type =
      tf_batchnorm_op.mean().getType().dyn_cast<RankedTensorType>();
  RankedTensorType variance_type =
      tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
  if (!variance_type || !mean_type) return failure();

  Value mean_val, variance_val;

  if (mean_type.getNumElements() == 0) {
    mean_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 0);
  } else {
    mean_val = tf_batchnorm_op.mean();
  }

  if (variance_type.getNumElements() == 0) {
    variance_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 1.0);
  } else {
    variance_val = tf_batchnorm_op.variance();
  }

  RankedTensorType epsilon_type =
      RankedTensorType::get({1}, variance_type.getElementType());
  auto epsilon_attr =
      DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
  auto epsilon_const = CreateOpAndInfer<tosa::ConstOp>(
      rewriter, op->getLoc(), epsilon_type, epsilon_attr);

  auto op1_sub_input_mean = CreateOpAndInfer<tosa::SubOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      tf_batchnorm_op.x(), mean_val);

  auto op2_add_var_epsilon = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), variance_val.getType(), variance_val,
      epsilon_const.getResult());

  auto op3_rsqrt_op2 = CreateOpAndInfer<tosa::RsqrtOp>(
      rewriter, op->getLoc(), variance_val.getType(),
      op2_add_var_epsilon.getResult());

  auto op4_mul_op1_op3 = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);

  auto op5_mul_op4_scale = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);

  auto op6_add_op5_offset = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());

  rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
  return success();
}

LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_batchnorm_op = cast<TF::FusedBatchNormV3Op>(op);

  if (tf_batchnorm_op.is_training())
    return rewriter.notifyMatchFailure(
        op, "unable to lower when is_training is set");

  for (auto value : tf_batchnorm_op.getResults().drop_front(1)) {
    if (!value.use_empty()) {
      // Really we should compute this still and let it DCE but I can't find
      // the math.
      return rewriter.notifyMatchFailure(
          op, "lowering does not support aggregate statistics");
    }
  }

  RankedTensorType output_type =
      tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Lowering:
  // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
  // op1 = sub(input, mean)
  // op2 = add(var, epsilon)
  // op3 = rsqrt(op2)
  // op4 = mul(op1, op3)
  // op5 = mul(op4, scale)
  // op6 = add(op5, offset)

  auto op1_sub_input_mean = CreateOpAndInfer<tosa::SubOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      tf_batchnorm_op.x(), tf_batchnorm_op.mean());

  RankedTensorType variance_type =
      tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
  if (!variance_type) return failure();

  auto epsilon_type =
      RankedTensorType::get({1}, variance_type.getElementType());
  auto epsilon_attr =
      DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
  auto epsilon_const = CreateOpAndInfer<tosa::ConstOp>(
      rewriter, op->getLoc(), epsilon_type, epsilon_attr);

  auto op2_add_var_epsilon = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.variance().getType(),
      tf_batchnorm_op.variance(), epsilon_const);

  auto op3_rsqrt_op2 = CreateOpAndInfer<tosa::RsqrtOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.variance().getType(),
      op2_add_var_epsilon.getResult());

  auto op4_mul_op1_op3 = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);

  auto op5_mul_op4_scale = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);

  auto op6_add_op5_offset = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());

  llvm::SmallVector<Value> replacements = {
      op6_add_op5_offset.getResult(), tf_batchnorm_op.mean(),
      tf_batchnorm_op.variance(),
      // The last three are reserved spaces and have no purpose currently.
      tf_batchnorm_op.mean(), tf_batchnorm_op.variance(),
      tf_batchnorm_op.variance()};
  rewriter.replaceOp(op, replacements);
  return success();
}

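// BiasAdd is a broadcast add of a bias vector, so it maps directly onto
// tosa.add.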
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1394 LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
1395     Operation* op, PatternRewriter& rewriter) const {
1396   auto tf_biasadd_op = cast<TF::BiasAddOp>(op);
1397 
1398   RankedTensorType output_type =
1399       tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
1400   // Not a ranked tensor output
1401   if (!output_type) return failure();
1402 
1403   auto add_op = CreateOpAndInfer<tosa::AddOp>(
1404       rewriter, op->getLoc(), output_type, tf_biasadd_op.value(),
1405       tf_biasadd_op.bias());
1406 
1407   rewriter.replaceOp(op, {add_op.getResult()});
1408   return success();
1409 }
1410 
1411 LogicalResult ConvertTFSliceOp::matchAndRewrite(
1412     Operation* op, PatternRewriter& rewriter) const {
1413   auto tf_slice_op = cast<TF::SliceOp>(op);
1414 
1415   RankedTensorType output_type =
1416       tf_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
1417   // Not a ranked tensor output
1418   if (!output_type) return failure();
1419 
1420   ElementsAttr begin_elems, size_elems;
1421 
1422   SmallVector<int64_t> begin_vals, size_vals;
1423 
1424   // begin is required to be a compile-time constant.
1425   if (!matchPattern(tf_slice_op.begin(), m_Constant(&begin_elems))) {
1426     return op->emitOpError("TF::Slice error: begin is not constant");
1427   }
1428 
1429   for (int i = 0; i < begin_elems.getNumElements(); i++)
1430     begin_vals.push_back(begin_elems.getValues<IntegerAttr>()[i].getInt());
1431 
1432   // Try to match size as compile-time constant first,
1433   // if this fails, use the output tensor shape instead.
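  // Illustrative example: for an input of shape [4, 6] with begin = [1, 2]
  // and a non-constant size, an output type of tensor<2x3xf32> yields
  // size_vals = {2, 3} taken from the output shape.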
1434   if (matchPattern(tf_slice_op.size(), m_Constant(&size_elems))) {
1435     for (int i = 0; i < size_elems.getNumElements(); i++)
1436       size_vals.push_back(size_elems.getValues<IntegerAttr>()[i].getInt());
1437   } else {
1438     size_vals.assign(output_type.getShape().begin(),
1439                      output_type.getShape().end());
1440   }
1441 
1442   ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
1443   ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);
1444 
1445   CreateReplaceOpAndInfer<tosa::SliceOp>(rewriter, op, output_type,
1446                                          tf_slice_op.input(), begin, size);
1447   return success();
1448 }
1449 
1450 LogicalResult ConvertTFTileOp::matchAndRewrite(
1451     Operation* op, PatternRewriter& rewriter) const {
1452   auto tf_tile_op = cast<TF::TileOp>(op);
1453 
1454   RankedTensorType output_type =
1455       tf_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
1456   // Not a ranked tensor output
1457   if (!output_type) return failure();
1458 
1459   ElementsAttr multiples_elems;
1460   if (!matchPattern(tf_tile_op.multiples(), m_Constant(&multiples_elems)))
1461     return failure();
1462   SmallVector<int64_t> multiples_vals;
1463   for (int i = 0; i < multiples_elems.getNumElements(); i++)
1464     multiples_vals.push_back(
1465         multiples_elems.getValues<IntegerAttr>()[i].getInt());
1466 
1467   ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);
1468 
1469   CreateReplaceOpAndInfer<tosa::TileOp>(rewriter, op, output_type,
1470                                         tf_tile_op.input(), multiples_attr);
1471 
1472   return success();
1473 }
1474 
1475 LogicalResult ConvertTFTransposeOp::matchAndRewrite(
1476     Operation* op, PatternRewriter& rewriter) const {
1477   auto tf_transpose_op = cast<TF::TransposeOp>(op);
1478 
1479   TensorType output_type =
1480       tf_transpose_op.getResult().getType().dyn_cast<TensorType>();
1481   // Not a tensor output
1482   if (!output_type) {
1483     return failure();
1484   }
1485 
1486   CreateReplaceOpAndInfer<tosa::TransposeOp>(
1487       rewriter, op, output_type, tf_transpose_op.x(), tf_transpose_op.perm());
1488 
1489   return success();
1490 }
1491 
1492 LogicalResult ConvertTFPackOp::matchAndRewrite(
1493     Operation* op, PatternRewriter& rewriter) const {
1494   auto tf_pack_op = cast<TF::PackOp>(op);
1495 
1496   SmallVector<Value> inputs(tf_pack_op.values());
1497 
1498   assert(inputs.size() >= 2);
1499 
1500   IntegerAttr axis_attr = tf_pack_op.axisAttr();
1501   if (!axis_attr) axis_attr = rewriter.getI64IntegerAttr(0);
1502 
1503   int32_t axis_i32 = axis_attr.getInt();
1504 
1505   llvm::Optional<Value> result =
1506       convertPackOp(rewriter, op, tf_pack_op.getResult(), inputs, axis_i32);
1507 
1508   if (!result) return failure();
1509 
1510   rewriter.replaceOp(op, {result.getValue()});
1511 
1512   return success();
1513 }
1514 
1515 LogicalResult ConvertTFUnpackOp::matchAndRewrite(
1516     Operation* op, PatternRewriter& rewriter) const {
1517   auto tf_unpack_op = cast<TF::UnpackOp>(op);
1518 
1519   IntegerAttr axis_attr = tf_unpack_op.axisAttr();
1520   if (!axis_attr) axis_attr = rewriter.getI64IntegerAttr(0);
1525   int32_t axis_i32 = axis_attr.getInt();
1526 
1527   llvm::Optional<SmallVector<Value>> results =
1528       convertUnpackOp(rewriter, op, tf_unpack_op.value(), axis_i32);
1529 
1530   if (!results) return failure();
1531 
1532   rewriter.replaceOp(op, results.getValue());
1533 
1534   return success();
1535 }
1536 
1537 // Splits the input into num_split parts along split_dim.
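// Illustrative example: splitting a [6, 4] tensor with num_split = 3 along
// axis 0 yields three [2, 4] results.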
1538 LogicalResult ConvertTFSplitOp::matchAndRewrite(
1539     Operation* op, PatternRewriter& rewriter) const {
1540   auto tf_split_op = cast<TF::SplitOp>(op);
1541 
1542   // Get the number of splits from the op's result count.
1543   auto range = tf_split_op.getODSResults(0);
1544   int32_t num_split = std::distance(range.begin(), range.end());
1547 
1548   // Get the axis; defaults to 0 if split_dim is not a compile-time constant.
1549   int32_t axis = 0;
1550   ElementsAttr axisAttrElems;
1551   if (matchPattern(tf_split_op.split_dim(), m_Constant(&axisAttrElems))) {
1552     axis = axisAttrElems.getValues<IntegerAttr>()[0].getInt();
1553   }
1554 
1555   llvm::Optional<SmallVector<Value>> results =
1556       convertSplitOp(rewriter, op, tf_split_op.getResult(0),
1557                      tf_split_op.value(), num_split, axis);
1558 
1559   if (!results) return failure();
1560 
1561   rewriter.replaceOp(op, results.getValue());
1562 
1563   return success();
1564 }
1565 
1566 // TF SplitV op splits the input based on a vector of sizes.
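// Illustrative example: splitting a [7, 4] tensor with size_splits = [1, 2, 4]
// along axis 0 yields results of shapes [1, 4], [2, 4], and [4, 4].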
1567 LogicalResult ConvertTFSplitVOp::matchAndRewrite(
1568     Operation* op, PatternRewriter& rewriter) const {
1569   auto tf_splitv_op = cast<TF::SplitVOp>(op);
1570 
1571   // Get the size_splits array
1572   SmallVector<int32_t> size_split;
1573   ElementsAttr size_split_elems;
1574   if (!matchPattern(tf_splitv_op.size_splits(),
1575                     m_Constant(&size_split_elems))) {
1576     return failure();
1577   }
1578 
1579   for (int i = 0; i < size_split_elems.getNumElements(); i++) {
1580     size_split.push_back(size_split_elems.getValues<IntegerAttr>()[i].getInt());
1581   }
1582 
1583   // Get the axis
1584   ElementsAttr axisAttrElems;
1585   if (!matchPattern(tf_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
1586     return op->emitOpError("Cannot read split_dim elems");
1587   }
1588 
1589   int32_t axis = axisAttrElems.getValues<IntegerAttr>()[0].getInt();
1590 
1591   llvm::Optional<SmallVector<Value>> results =
1592       convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0),
1593                       tf_splitv_op.value(), size_split, axis);
1594 
1595   if (!results) return failure();
1596 
1597   rewriter.replaceOp(op, results.getValue());
1598 
1599   return success();
1600 }
1601 
1602 LogicalResult ConvertTFLessOp::matchAndRewrite(
1603     Operation* op, PatternRewriter& rewriter) const {
1604   auto tf_less_op = cast<TF::LessOp>(op);
1605 
1606   TensorType output_type =
1607       tf_less_op.getResult().getType().dyn_cast<TensorType>();
1608   // Not a tensor output
1609   if (!output_type) return failure();
1610 
1611   // less(x, y) is not(greater_equal(x, y))
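  // e.g. x = 1, y = 2: greater_equal(1, 2) is false, so less(1, 2) is true.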
1612   auto greater_equal_op = CreateOpAndInfer<tosa::GreaterEqualOp>(
1613       rewriter, op->getLoc(), output_type, tf_less_op.x(), tf_less_op.y());
1614 
1615   auto not_op = CreateOpAndInfer<tosa::LogicalNotOp>(
1616       rewriter, op->getLoc(), output_type, greater_equal_op.getResult());
1617 
1618   rewriter.replaceOp(op, {not_op.getResult()});
1619   return success();
1620 }
1621 
1622 LogicalResult ConvertTFLessEqualOp::matchAndRewrite(
1623     Operation* op, PatternRewriter& rewriter) const {
1624   auto tf_less_equal_op = cast<TF::LessEqualOp>(op);
1625 
1626   TensorType output_type =
1627       tf_less_equal_op.getResult().getType().dyn_cast<TensorType>();
1628   // Not a tensor output
1629   if (!output_type) return failure();
1630 
1631   // less_equal(x, y) is not(greater(x, y))
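  // e.g. x = 2, y = 2: greater(2, 2) is false, so less_equal(2, 2) is true.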
1632   auto greater_op = CreateOpAndInfer<tosa::GreaterOp>(
1633       rewriter, op->getLoc(), output_type, tf_less_equal_op.x(),
1634       tf_less_equal_op.y());
1635   auto not_op = CreateOpAndInfer<tosa::LogicalNotOp>(
1636       rewriter, op->getLoc(), output_type, greater_op.getResult());
1637 
1638   rewriter.replaceOp(op, {not_op.getResult()});
1639   return success();
1640 }
1641 
1642 LogicalResult ConvertTFPadOp::matchAndRewrite(Operation* op,
1643                                               PatternRewriter& rewriter) const {
1644   auto tf_pad_op = cast<TF::PadOp>(op);
1645 
1646   TensorType output_type =
1647       tf_pad_op.getResult().getType().dyn_cast<TensorType>();
1648   // Not a tensor output
1649   if (!output_type) return failure();
1650 
1651   auto pad_op =
1652       CreateOpAndInfer<tosa::PadOp>(rewriter, op->getLoc(), output_type,
1653                                     tf_pad_op.input(), tf_pad_op.paddings());
1654 
1655   rewriter.replaceOp(op, {pad_op.getResult()});
1656   return success();
1657 }
1658 
1659 LogicalResult ConvertTFResizeBilinearOp::matchAndRewrite(
1660     Operation* op, PatternRewriter& rewriter) const {
1661   auto tf_resize_op = cast<TF::ResizeBilinearOp>(op);
1662 
1663   RankedTensorType output_type =
1664       tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
1665   // Not a ranked tensor output
1666   if (!output_type) return failure();
1667 
1668   llvm::Optional<Value> result = convertResizeOp(
1669       rewriter, op, output_type, tf_resize_op.images(), StringRef("BILINEAR"),
1670       tf_resize_op.align_cornersAttr().getValue(),
1671       tf_resize_op.half_pixel_centersAttr().getValue());
1672 
1673   if (!result) return failure();
1674 
1675   rewriter.replaceOp(op, {result.getValue()});
1676 
1677   return success();
1678 }
1679 
1680 LogicalResult ConvertTFResizeNearestNeighborOp::matchAndRewrite(
1681     Operation* op, PatternRewriter& rewriter) const {
1682   auto tf_resize_op = cast<TF::ResizeNearestNeighborOp>(op);
1683 
1684   RankedTensorType output_type =
1685       tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
1686   // Not a ranked tensor output
1687   if (!output_type) return failure();
1688 
1689   llvm::Optional<Value> result =
1690       convertResizeOp(rewriter, op, output_type, tf_resize_op.images(),
1691                       StringRef("NEAREST_NEIGHBOR"),
1692                       tf_resize_op.align_cornersAttr().getValue(),
1693                       tf_resize_op.half_pixel_centersAttr().getValue());
1694 
1695   if (!result) return failure();
1696 
1697   rewriter.replaceOp(op, {result.getValue()});
1698 
1699   return success();
1700 }
1701 
1702 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
1703     Operation* op, PatternRewriter& rewriter) const {
1704   auto tf_matmul_op = cast<TF::MatMulOp>(op);
1705 
1706   RankedTensorType a_type =
1707       tf_matmul_op.a().getType().dyn_cast<RankedTensorType>();
1708   RankedTensorType b_type =
1709       tf_matmul_op.b().getType().dyn_cast<RankedTensorType>();
1710   RankedTensorType output_type =
1711       tf_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();
1712 
1713   if (!(a_type && b_type && output_type)) {
1714     return op->emitOpError("MatMul: a/b/output not ranked tensors");
1715   }
1716 
1717   if (a_type.getRank() != b_type.getRank() ||
1718       a_type.getRank() != output_type.getRank()) {
1719     return op->emitOpError("MatMul: a/b/output rank must match");
1720   }
1721 
1722   // Can only handle rank 2 tensors for tf.MatMul.
1723   // Cases with rank > 2 tensors should be handled by tf.BatchMatMul or
1724   // tf.BatchMatMulV2.
1725   if (a_type.getRank() != 2) {
1726     return op->emitOpError("MatMul: a/b/output rank must be 2");
1727   }
1728 
1729   SmallVector<int64_t, 3> batch_a_shape(
1730       {1, a_type.getShape()[0], a_type.getShape()[1]});
1731   SmallVector<int64_t, 3> batch_b_shape(
1732       {1, b_type.getShape()[0], b_type.getShape()[1]});
1733   SmallVector<int64_t, 3> batch_output_shape(
1734       {1, output_type.getShape()[0], output_type.getShape()[1]});
1735 
1736   RankedTensorType batch_a_type =
1737       RankedTensorType::get(batch_a_shape, a_type.getElementType());
1738   RankedTensorType batch_b_type =
1739       RankedTensorType::get(batch_b_shape, b_type.getElementType());
1740   RankedTensorType batch_output_type =
1741       RankedTensorType::get(batch_output_shape, output_type.getElementType());
1742 
1743   // Need to reshape input and output since TOSA matmul only supports
1744   // [N, H, C] * [N, C, W] -> [N, H, W].
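  // Illustrative example: a: [3, 4] and b: [4, 5] are reshaped to [1, 3, 4]
  // and [1, 4, 5]; tosa.matmul produces [1, 3, 5], and the final reshape
  // restores the rank-2 result [3, 5].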
1745   auto op1_reshape_a = CreateOpAndInfer<tosa::ReshapeOp>(
1746       rewriter, op->getLoc(), batch_a_type, tf_matmul_op.a(),
1747       rewriter.getI64ArrayAttr(batch_a_shape));
1748 
1749   auto op2_reshape_b = CreateOpAndInfer<tosa::ReshapeOp>(
1750       rewriter, op->getLoc(), batch_b_type, tf_matmul_op.b(),
1751       rewriter.getI64ArrayAttr(batch_b_shape));
1752 
1753   auto op3_matmul_op1_op2 = CreateOpAndInfer<tosa::MatMulOp>(
1754       rewriter, op->getLoc(), batch_output_type, op1_reshape_a.getResult(),
1755       op2_reshape_b.getResult());
1756 
1757   CreateReplaceOpAndInfer<tosa::ReshapeOp>(
1758       rewriter, op, output_type, op3_matmul_op1_op2.getResult(),
1759       rewriter.getI64ArrayAttr(output_type.getShape()));
1760 
1761   return success();
1762 }
1763 
1764 LogicalResult ConvertTFGatherOp::matchAndRewrite(
1765     Operation* op, PatternRewriter& rewriter) const {
1766   auto tf_gather_op = cast<TF::GatherOp>(op);
1767 
1768   // tf.Gather is equivalent to tf.GatherV2 with batch_dims = 0, axis = 0
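  // e.g. (illustrative) params: [4, 3], indices: [2] -> result: [2, 3].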
1769   int32_t batch_dims = 0;
1770   int32_t axis = 0;
1771 
1772   llvm::Optional<Value> result = convertGatherOp(
1773       rewriter, op, tf_gather_op.getResult(), tf_gather_op.params(),
1774       tf_gather_op.indices(), batch_dims, axis);
1775 
1776   if (!result) return failure();
1777 
1778   rewriter.replaceOp(op, {result.getValue()});
1779 
1780   return success();
1781 }
1782 
1783 LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
1784     Operation* op, PatternRewriter& rewriter) const {
1785   auto tf_gather_op = cast<TF::GatherV2Op>(op);
1786 
1787   // Axis is a tensor.  Pull out the one integer value.
1788   ElementsAttr axis_elem;
1789   if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem)))
1790     return failure();
1791   assert(axis_elem.getNumElements() == 1);
1792 
1793   int32_t axis = axis_elem.getValues<IntegerAttr>()[0].getInt();
1794   int32_t batch_dims = tf_gather_op.batch_dimsAttr().getInt();
1795 
1796   llvm::Optional<Value> result = convertGatherOp(
1797       rewriter, op, tf_gather_op.getResult(), tf_gather_op.params(),
1798       tf_gather_op.indices(), batch_dims, axis);
1799 
1800   if (!result) return failure();
1801 
1802   rewriter.replaceOp(op, {result.getValue()});
1803 
1804   return success();
1805 }
1806 
1807 LogicalResult ConvertTFGatherNdOp::matchAndRewrite(
1808     Operation* op, PatternRewriter& rewriter) const {
1809   auto tf_gathernd_op = cast<TF::GatherNdOp>(op);
1810 
1811   llvm::Optional<Value> result =
1812       convertGatherNdOp(rewriter, op, tf_gathernd_op.getResult(),
1813                         tf_gathernd_op.params(), tf_gathernd_op.indices());
1814 
1815   if (!result) return failure();
1816 
1817   rewriter.replaceOp(op, {result.getValue()});
1818 
1819   return success();
1820 }
1821 
1822 LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
1823     Operation* op, PatternRewriter& rewriter) const {
1824   auto tf_sel_op = cast<TF::SelectV2Op>(op);
1825 
1826   llvm::Optional<Value> result =
1827       convertSelectOp(rewriter, op, tf_sel_op.getResult(),
1828                       tf_sel_op.condition(), tf_sel_op.t(), tf_sel_op.e());
1829 
1830   if (!result) return failure();
1831 
1832   rewriter.replaceOp(op, {result.getValue()});
1833 
1834   return success();
1835 }
1836 
1837 LogicalResult ConvertTFSpaceToDepthOp::matchAndRewrite(
1838     Operation* op, PatternRewriter& rewriter) const {
1839   auto tf_s2d_op = cast<TF::SpaceToDepthOp>(op);
1840 
1841   llvm::Optional<Value> result = convertSpaceToDepthOp(
1842       rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.input(),
1843       tf_s2d_op.block_sizeAttr(), tf_s2d_op.data_formatAttr());
1844 
1845   if (!result) return failure();
1846 
1847   rewriter.replaceOp(op, {result.getValue()});
1848 
1849   return success();
1850 }
1851 
1852 LogicalResult ConvertTFDepthToSpaceOp::matchAndRewrite(
1853     Operation* op, PatternRewriter& rewriter) const {
1854   auto tf_d2s_op = cast<TF::DepthToSpaceOp>(op);
1855 
1856   llvm::Optional<Value> result = convertDepthToSpaceOp(
1857       rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.input(),
1858       tf_d2s_op.block_sizeAttr(), tf_d2s_op.data_formatAttr());
1859 
1860   if (!result) return failure();
1861 
1862   rewriter.replaceOp(op, {result.getValue()});
1863 
1864   return success();
1865 }
1866 
1867 LogicalResult ConvertTFSpaceToBatchNDOp::matchAndRewrite(
1868     Operation* op, PatternRewriter& rewriter) const {
1869   auto tf_s2b_op = cast<TF::SpaceToBatchNDOp>(op);
1870 
1871   llvm::Optional<Value> result = convertSpaceToBatchNDOp(
1872       rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.input(),
1873       tf_s2b_op.block_shape(), tf_s2b_op.paddings());
1874   if (!result) return failure();
1875 
1876   rewriter.replaceOp(op, {result.getValue()});
1877 
1878   return success();
1879 }
1880 
1881 LogicalResult ConvertTFBatchToSpaceNDOp::matchAndRewrite(
1882     Operation* op, PatternRewriter& rewriter) const {
1883   auto tf_b2s_op = cast<TF::BatchToSpaceNDOp>(op);
1884 
1885   llvm::Optional<Value> result = convertBatchToSpaceNDOp(
1886       rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.input(),
1887       tf_b2s_op.block_shape(), tf_b2s_op.crops());
1888 
1889   if (!result) return failure();
1890 
1891   rewriter.replaceOp(op, {result.getValue()});
1892 
1893   return success();
1894 }
1895 
1896 LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
1897     Operation* op, PatternRewriter& rewriter) const {
1898   auto tf_ss_op = cast<TF::StridedSliceOp>(op);
1899 
1900   llvm::Optional<Value> result = convertStridedSliceOp(
1901       rewriter, op, tf_ss_op.getResult(), tf_ss_op.input(), tf_ss_op.begin(),
1902       tf_ss_op.end(), tf_ss_op.strides(), tf_ss_op.begin_maskAttr().getInt(),
1903       tf_ss_op.end_maskAttr().getInt(), tf_ss_op.ellipsis_maskAttr().getInt(),
1904       tf_ss_op.new_axis_maskAttr().getInt(),
1905       tf_ss_op.shrink_axis_maskAttr().getInt());
1906 
1907   if (!result) return failure();
1908 
1909   rewriter.replaceOp(op, {result.getValue()});
1910 
1911   return success();
1912 }
1913 
1914 LogicalResult ConvertTFZerosLikeOp::matchAndRewrite(
1915     Operation* op, PatternRewriter& rewriter) const {
1916   auto tf_zeroslike_op = cast<TF::ZerosLikeOp>(op);
1917 
1918   llvm::Optional<Value> result = convertZerosLikeOp(
1919       rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.x());
1920 
1921   if (!result) return failure();
1922 
1923   rewriter.replaceOp(op, {result.getValue()});
1924 
1925   return success();
1926 }
1927 
1928 LogicalResult ConvertTFSigmoidOp::matchAndRewrite(
1929     Operation* op, PatternRewriter& rewriter) const {
1930   auto tf_sigmoid_op = cast<TF::SigmoidOp>(op);
1931   TensorType output_type =
1932       tf_sigmoid_op.getResult().getType().dyn_cast<TensorType>();
1933   if (!output_type) return failure();
1934 
1935   CreateReplaceOpAndInfer<tosa::SigmoidOp>(rewriter, op, output_type,
1936                                            tf_sigmoid_op.x());
1937 
1938   return success();
1939 }
1940 
1941 LogicalResult ConvertTFTanhOp::matchAndRewrite(
1942     Operation* op, PatternRewriter& rewriter) const {
1943   auto tf_tanh_op = cast<TF::TanhOp>(op);
1944   TensorType output_type =
1945       tf_tanh_op.getResult().getType().dyn_cast<TensorType>();
1946   if (!output_type) return failure();
1947 
1948   CreateReplaceOpAndInfer<tosa::TanhOp>(rewriter, op, output_type,
1949                                         tf_tanh_op.x());
1950 
1951   return success();
1952 }
1953 
1954 LogicalResult ConvertTFLeakyReluOp::matchAndRewrite(
1955     Operation* op, PatternRewriter& rewriter) const {
1956   auto tf_leakyrelu_op = cast<TF::LeakyReluOp>(op);
1957   TensorType output_type =
1958       tf_leakyrelu_op.getResult().getType().dyn_cast<TensorType>();
1959   if (!output_type) return failure();
1960 
1961   // Implement LeakyRelu as element-wise:
1962   //   out = x > 0 ? x : alpha * x
1963   //
1964   // In TOSA ops:
1965   //
1966   //   const_zero = constant(0)
1967   //   a1 = mul(x, alpha)
1968   //   a2 = greater_equal(x, const_zero)
1969   //   out = select(a2, x, a1)
1970   //
1971   // If alpha can be constrained to 0.0 <= alpha <= 1.0, then
1972   // an alternative simpler lowering could be implemented with:
1973   //
1974   //   max(mul(x, alapha), x)
1975   //
1976   // But this alternative is not robust unless alpha meets those constraints.
1977 
1978   if (!output_type.getElementType().isF32()) {
1979     op->emitOpError("ConvertTFLeakyReluOp: only support F32");
1980     return failure();
1981   }
1982 
1983   FloatAttr tmpAttr = tf_leakyrelu_op.alphaAttr();
1984   // There is disagreement between the MLIR .td defaults and TF
1985   // documentation on 0.2 vs 0.3, but 0.2 will be used here.
1986   double alpha = 0.2;
1987 
1988   if (tmpAttr) {
1989     alpha = tmpAttr.getValueAsDouble();
1990   }
1991 
1992   Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);
1993 
1994   auto a1_mul = CreateOpAndInfer<tosa::MulOp>(
1995       rewriter, op->getLoc(), output_type, tf_leakyrelu_op.features(),
1996       getTosaConstTensorSingleF32(rewriter, op, alpha), 0);
1997 
1998   auto a2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
1999       rewriter, op->getLoc(), UnrankedTensorType::get(rewriter.getI1Type()),
2000       tf_leakyrelu_op.features(), const_zero);
2001 
2002   auto a3_select = CreateOpAndInfer<tosa::SelectOp>(
2003       rewriter, op->getLoc(), output_type, a2_ge, tf_leakyrelu_op.features(),
2004       a1_mul.getResult());
2005 
2006   rewriter.replaceOp(op, {a3_select.getResult()});
2007 
2008   return success();
2009 }
2010 
2011 LogicalResult ConvertTFNegOp::matchAndRewrite(Operation* op,
2012                                               PatternRewriter& rewriter) const {
2013   auto tf_neg_op = cast<TF::NegOp>(op);
2014   TensorType output_type =
2015       tf_neg_op.getResult().getType().dyn_cast<TensorType>();
2016   if (!output_type) return failure();
2017 
2018   CreateReplaceOpAndInfer<tosa::NegateOp>(rewriter, op, output_type,
2019                                           tf_neg_op.x());
2020 
2021   return success();
2022 }
2023 
2024 LogicalResult ConvertTFStopGradientOp::matchAndRewrite(
2025     Operation* op, PatternRewriter& rewriter) const {
2026   auto tf_stopgrad_op = cast<TF::StopGradientOp>(op);
2027   TensorType output_type =
2028       tf_stopgrad_op.getResult().getType().dyn_cast<TensorType>();
2029   if (!output_type) return failure();
2030 
2031   CreateReplaceOpAndInfer<tosa::IdentityOp>(rewriter, op, output_type,
2032                                             tf_stopgrad_op.input());
2033 
2034   return success();
2035 }
2036 
2037 LogicalResult ConvertTFReverseV2Op::matchAndRewrite(
2038     Operation* op, PatternRewriter& rewriter) const {
2039   auto tf_reverse_op = cast<TF::ReverseV2Op>(op);
2040   RankedTensorType input_type =
2041       tf_reverse_op.tensor().getType().dyn_cast<RankedTensorType>();
2042   TensorType output_type =
2043       tf_reverse_op.getResult().getType().dyn_cast<TensorType>();
2044   if (!input_type || !output_type) return failure();
2045 
2046   ElementsAttr axis_elems;
2047   if (!matchPattern(tf_reverse_op.axis(), m_Constant(&axis_elems)))
2048     return failure();
2049 
2050   auto input_rank = input_type.getShape().size();
2051   Value val = tf_reverse_op.tensor();
2052   if (axis_elems.getNumElements() == 0) {
2053     auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
2054         rewriter, op->getLoc(), output_type, val);
2055     val = identity_op.getResult();
2056   } else {
2057     for (int i = 0; i < axis_elems.getNumElements(); i++) {
2058       int64_t axis_val = axis_elems.getValues<IntegerAttr>()[i].getInt();
2059       if (axis_val < 0) axis_val += input_rank;
2060       auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
2061       auto reverse_op = CreateOpAndInfer<tosa::ReverseOp>(
2062           rewriter, op->getLoc(), output_type, val, axis_attr);
2063 
2064       val = reverse_op.getResult();
2065     }
2066   }
2067 
2068   rewriter.replaceOp(op, {val});
2069 
2070   return success();
2071 }
2072 
2073 LogicalResult ConvertTFFakeQuantWithMinMaxArgsOp::matchAndRewrite(
2074     Operation* op, PatternRewriter& rewriter) const {
2075   auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxArgsOp>(op);
2076 
2077   TensorType output_type =
2078       tf_fakequant_op.getResult().getType().dyn_cast<TensorType>();
2079   // Not a tensor output
2080   if (!output_type) return failure();
2081 
2082   llvm::Optional<Value> result =
2083       convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.inputs(),
2084                          tf_fakequant_op.minAttr().getValueAsDouble(),
2085                          tf_fakequant_op.maxAttr().getValueAsDouble(),
2086                          tf_fakequant_op.num_bitsAttr().getInt(),
2087                          tf_fakequant_op.narrow_rangeAttr().getValue());
2088 
2089   if (!result) return failure();
2090 
2091   rewriter.replaceOp(op, {result.getValue()});
2092 
2093   return success();
2094 }
2095 
2096 LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite(
2097     Operation* op, PatternRewriter& rewriter) const {
2098   auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);
2099 
2100   TensorType output_type =
2101       tf_fakequant_op.getResult().getType().dyn_cast<TensorType>();
2102   // Not a tensor output
2103   if (!output_type) return failure();
2104 
2105   // Only min/max that can be matched as compile-time scalar constants are
2105   // supported.
2106   ElementsAttr min_elems, max_elems;
2107   if (!matchPattern(tf_fakequant_op.min(), m_Constant(&min_elems)))
2108     return failure();
2109 
2110   if (!matchPattern(tf_fakequant_op.max(), m_Constant(&max_elems)))
2111     return failure();
2112 
2113   if (min_elems.getNumElements() != 1 || max_elems.getNumElements() != 1)
2114     return failure();
2115 
2116   double min_val = min_elems.getValues<FloatAttr>()[0].getValueAsDouble();
2117   double max_val = max_elems.getValues<FloatAttr>()[0].getValueAsDouble();
2118 
2119   llvm::Optional<Value> result = convertFakeQuantOp(
2120       rewriter, op, output_type, tf_fakequant_op.inputs(), min_val, max_val,
2121       tf_fakequant_op.num_bitsAttr().getInt(),
2122       tf_fakequant_op.narrow_rangeAttr().getValue());
2123 
2124   if (!result) return failure();
2125 
2126   rewriter.replaceOp(op, {result.getValue()});
2127 
2128   return success();
2129 }
2130 
2131 LogicalResult ConvertTFLeftShiftOp::matchAndRewrite(
2132     Operation* op, PatternRewriter& rewriter) const {
2133   auto tf_left_shift_op = cast<TF::LeftShiftOp>(op);
2134 
2135   TensorType output_type =
2136       tf_left_shift_op.getResult().getType().dyn_cast<TensorType>();
2137   if (!output_type) return failure();
2138 
2139   CreateReplaceOpAndInfer<tosa::LogicalLeftShiftOp>(
2140       rewriter, op, output_type, tf_left_shift_op.x(), tf_left_shift_op.y());
2141 
2142   return success();
2143 }
2144 
2145 LogicalResult ConvertTFRightShiftOp::matchAndRewrite(
2146     Operation* op, PatternRewriter& rewriter) const {
2147   // Performs a logical shift for unsigned integer types, and an arithmetic
2148   // shift for signed integer types.
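  // e.g. (illustrative) int8: -8 >> 1 = -4 (arithmetic, sign-extended),
  // while uint8: 248 >> 1 = 124 (logical, zero-filled).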
2149   auto tf_right_shift_op = cast<TF::RightShiftOp>(op);
2150 
2151   TensorType output_type =
2152       tf_right_shift_op.getResult().getType().dyn_cast<TensorType>();
2153   if (!output_type) return failure();
2154 
2155   Type output_element_type = output_type.getElementType();
2156 
2157   bool is_signed = !output_element_type.isUnsignedInteger();
2159 
2160   if (is_signed) {
2161     CreateReplaceOpAndInfer<tosa::ArithmeticRightShiftOp>(
2162         rewriter, op, output_type, tf_right_shift_op.x(), tf_right_shift_op.y(),
2163         false);
2164   } else {
2165     CreateReplaceOpAndInfer<tosa::LogicalRightShiftOp>(
2166         rewriter, op, output_type, tf_right_shift_op.x(),
2167         tf_right_shift_op.y());
2168   }
2169 
2170   return success();
2171 }
2172 
2173 LogicalResult ConvertTFOneHotOp::matchAndRewrite(
2174     Operation* op, PatternRewriter& rewriter) const {
2175   auto tf_one_hot_op = cast<TF::OneHotOp>(op);
2176 
2177   ElementsAttr depth_elems;
2178   if (!matchPattern(tf_one_hot_op.depth(), m_Constant(&depth_elems)))
2179     return failure();
2180   int32_t depth = depth_elems.getValues<IntegerAttr>()[0].getInt();
2181 
2182   IntegerAttr axisAttr = tf_one_hot_op.axisAttr();
2183   int32_t axis = axisAttr.getInt();
2184 
2185   llvm::Optional<Value> result = convertOneHotOp(
2186       rewriter, op, tf_one_hot_op.getResult(), tf_one_hot_op.indices(),
2187       tf_one_hot_op.on_value(), tf_one_hot_op.off_value(), depth, axis);
2188 
2189   if (!result) return failure();
2190 
2191   rewriter.replaceOp(op, {result.getValue()});
2192 
2193   return success();
2194 }
2195 
2196 LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite(
2197     Operation* op, PatternRewriter& rewriter) const {
2198   auto tf_batch_matmul_op = cast<TF::BatchMatMulV2Op>(op);
2199 
2200   RankedTensorType x_type =
2201       tf_batch_matmul_op.x().getType().dyn_cast<RankedTensorType>();
2202   RankedTensorType y_type =
2203       tf_batch_matmul_op.y().getType().dyn_cast<RankedTensorType>();
2204   RankedTensorType output_type =
2205       tf_batch_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();
2206 
2207   if (!(x_type && y_type && output_type)) {
2208     return op->emitOpError("BatchMatMulV2: x/y/output not ranked tensors");
2209   }
2210 
2211   if (x_type.getRank() != y_type.getRank() ||
2212       x_type.getRank() != output_type.getRank()) {
2213     return op->emitOpError("BatchMatMulV2: x/y/output rank must match");
2214   }
2215 
2216   if (x_type.getRank() <= 2) {
2217     return op->emitOpError("BatchMatMulV2: x/y/output rank must > 2");
2218   }
2219 
2220   // Rank 3 batch matmul can be directly mapped to tosa.matmul trivially.
2221   if (x_type.getRank() == 3) {
2222     CreateReplaceOpAndInfer<tosa::MatMulOp>(rewriter, op, output_type,
2223                                             tf_batch_matmul_op.x(),
2224                                             tf_batch_matmul_op.y());
2225   } else {
2226     // 1. Reshape x (and similarly y) from
2227     //  [a0, a1, ..., an, H, C] to [N, H, C],
2228     //  where N = a0 * a1 * ... * an.
2229     // 2. tosa.MatMul
2230     //  [N, H, C] * [N, C, W] -> [N, H, W].
2231     // 3. Reshape output from:
2232     //  [N, H, W] to [a0, a1, ... , an, H, W]
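    // Illustrative example: x: [2, 3, 4, 5] * y: [2, 3, 5, 6] is reshaped to
    // [6, 4, 5] * [6, 5, 6]; tosa.matmul yields [6, 4, 6], and the final
    // reshape restores [2, 3, 4, 6].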
2233     int64_t rank = x_type.getRank();
2234     int64_t N = 1;
2235     for (int i = 0; i < (rank - 2); i++) {
2236       N *= x_type.getShape()[i];
2237     }
2238     int64_t H = x_type.getShape()[rank - 2];
2239     int64_t C = x_type.getShape()[rank - 1];
2240     int64_t W = y_type.getShape()[rank - 1];
2241 
2242     SmallVector<int64_t, 3> rank3_x_shape({N, H, C});
2243     SmallVector<int64_t, 3> rank3_y_shape({N, C, W});
2244     SmallVector<int64_t, 3> rank3_output_shape({N, H, W});
2245 
2246     RankedTensorType rank3_x_type =
2247         RankedTensorType::get(rank3_x_shape, x_type.getElementType());
2248     RankedTensorType rank3_y_type =
2249         RankedTensorType::get(rank3_y_shape, y_type.getElementType());
2250     RankedTensorType rank3_output_type =
2251         RankedTensorType::get(rank3_output_shape, output_type.getElementType());
2252 
2253     auto op1_reshape_x = CreateOpAndInfer<tosa::ReshapeOp>(
2254         rewriter, op->getLoc(), rank3_x_type, tf_batch_matmul_op.x(),
2255         rewriter.getI64ArrayAttr(rank3_x_shape));
2256 
2257     auto op2_reshape_y = CreateOpAndInfer<tosa::ReshapeOp>(
2258         rewriter, op->getLoc(), rank3_y_type, tf_batch_matmul_op.y(),
2259         rewriter.getI64ArrayAttr(rank3_y_shape));
2260 
2261     auto op3_matmul_op1_op2 = CreateOpAndInfer<tosa::MatMulOp>(
2262         rewriter, op->getLoc(), rank3_output_type, op1_reshape_x.getResult(),
2263         op2_reshape_y.getResult());
2264 
2265     CreateReplaceOpAndInfer<tosa::ReshapeOp>(
2266         rewriter, op, output_type, op3_matmul_op1_op2.getResult(),
2267         rewriter.getI64ArrayAttr(output_type.getShape()));
2268   }
2269   return success();
2270 }
2271 
2272 void LegalizeTF::runOnOperation() {
2273   auto* ctx = &getContext();
2274   RewritePatternSet patterns(ctx);
2275   auto func = getOperation();
2276   populateLegalizeTFPatterns(ctx, patterns);
2277 
2278   if (ApplyPatternsWithShapeResolution(func, std::move(patterns)).failed()) {
2279     signalPassFailure();
2280   }
2281 }
2282 
2283 }  // anonymous namespace
2284 
2285 void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) {
2286   // Add the generated patterns to the list.
2287   populateWithGenerated(patterns);
2288   patterns.add<ConvertTFMatMulOp>(ctx);
2289   patterns.add<ConvertTFReluOp>(ctx);
2290   patterns.add<ConvertTFRelu6Op>(ctx);
2291   patterns.add<ConvertTFEqualOp>(ctx);
2292   patterns.add<ConvertTFNotEqualOp>(ctx);
2293   patterns.add<ConvertTFGreaterOp>(ctx);
2294   patterns.add<ConvertTFGreaterEqualOp>(ctx);
2295   patterns.add<ConvertTFAddOp>(ctx);
2296   patterns.add<ConvertTFAddV2Op>(ctx);
2297   patterns.add<ConvertTFAddNOp>(ctx);
2298   patterns.add<ConvertTFSubOp>(ctx);
2299   patterns.add<ConvertTFMulOp>(ctx);
2300   patterns.add<ConvertTFSquareOp>(ctx);
2301   patterns.add<ConvertTFSquaredDifferenceOp>(ctx);
2302   patterns.add<ConvertTFRoundOp>(ctx);
2303   patterns.add<ConvertTFFloorDivOp>(ctx);
2304   patterns.add<ConvertTFFloorModOp>(ctx);
2305   patterns.add<ConvertTFAssertOp>(ctx);
2306   patterns.add<ConvertTFMaximumOp>(ctx);
2307   patterns.add<ConvertTFMinimumOp>(ctx);
2308   patterns.add<ConvertTFRealDivOp>(ctx);
2309   patterns.add<ConvertTFArgMaxOp>(ctx);
2310   patterns.add<ConvertTFAvgPoolOp>(ctx);
2311   patterns.add<ConvertTFMaxPoolOp>(ctx);
2312   patterns.add<ConvertTFConcatV2Op>(ctx);
2313   patterns.add<ConvertTFReshapeOp>(ctx);
2314   patterns.add<ConvertTFRankOp>(ctx);
2315   patterns.add<ConvertTFShapeOp>(ctx);
2316   patterns.add<ConvertTFExpandDimsOp>(ctx);
2317   patterns.add<ConvertTFSqueezeOp>(ctx);
2318   patterns.add<ConvertTFFillOp>(ctx);
2319   patterns.add<ConvertTFConv2DOp>(ctx);
2320   patterns.add<ConvertTFDepthwiseConv2dNativeOp>(ctx);
2321   patterns.add<ConvertTFConv2DBackpropInputOp>(ctx);
2322   patterns.add<ConvertTFEluOp>(ctx);
2323   patterns.add<ConvertTFSoftmaxOp>(ctx);
2324   patterns.add<ConvertTFLogSoftmaxOp>(ctx);
2325   patterns.add<ConvertTFAllOp>(ctx);
2326   patterns.add<ConvertTFAnyOp>(ctx);
2327   patterns.add<ConvertTFMaxOp>(ctx);
2328   patterns.add<ConvertTFMinOp>(ctx);
2329   patterns.add<ConvertTFMeanOp>(ctx);
2330   patterns.add<ConvertTFProdOp>(ctx);
2331   patterns.add<ConvertTFSumOp>(ctx);
2332   patterns.add<ConvertTFFusedBatchNormOp>(ctx);
2333   patterns.add<ConvertTFFusedBatchNormV3Op>(ctx);
2334   patterns.add<ConvertTFBiasAddOp>(ctx);
2335   patterns.add<ConvertTFSplitOp>(ctx);
2336   patterns.add<ConvertTFSplitVOp>(ctx);
2337   patterns.add<ConvertTFPackOp>(ctx);
2338   patterns.add<ConvertTFUnpackOp>(ctx);
2339   patterns.add<ConvertTFTransposeOp>(ctx);
2340   patterns.add<ConvertTFTileOp>(ctx);
2341   patterns.add<ConvertTFSliceOp>(ctx);
2342   patterns.add<ConvertTFStridedSliceOp>(ctx);
2343   patterns.add<ConvertTFLessOp>(ctx);
2344   patterns.add<ConvertTFLessEqualOp>(ctx);
2345   patterns.add<ConvertTFPadOp>(ctx);
2346   patterns.add<ConvertTFResizeBilinearOp>(ctx);
2347   patterns.add<ConvertTFResizeNearestNeighborOp>(ctx);
2348   patterns.add<ConvertTFGatherOp>(ctx);
2349   patterns.add<ConvertTFGatherV2Op>(ctx);
2350   patterns.add<ConvertTFGatherNdOp>(ctx);
2351   patterns.add<ConvertTFSelectV2Op>(ctx);
2352   patterns.add<ConvertTFSpaceToDepthOp>(ctx);
2353   patterns.add<ConvertTFDepthToSpaceOp>(ctx);
2354   patterns.add<ConvertTFSpaceToBatchNDOp>(ctx);
2355   patterns.add<ConvertTFBatchToSpaceNDOp>(ctx);
2356   patterns.add<ConvertTFZerosLikeOp>(ctx);
2357   patterns.add<ConvertTFSigmoidOp>(ctx);
2358   patterns.add<ConvertTFTanhOp>(ctx);
2359   patterns.add<ConvertTFLeakyReluOp>(ctx);
2360   patterns.add<ConvertTFNegOp>(ctx);
2361   patterns.add<ConvertTFStopGradientOp>(ctx);
2362   patterns.add<ConvertTFReverseV2Op>(ctx);
2363   patterns.add<ConvertTFFakeQuantWithMinMaxArgsOp>(ctx);
2364   patterns.add<ConvertTFFakeQuantWithMinMaxVarsOp>(ctx);
2365   patterns.add<ConvertTFLeftShiftOp>(ctx);
2366   patterns.add<ConvertTFRightShiftOp>(ctx);
2367   patterns.add<ConvertTFOneHotOp>(ctx);
2368   patterns.add<ConvertTFBatchMatMulV2Op>(ctx);
2369 }
2370 
2371 // Creates an instance of the TensorFlow dialect LegalizeTF pass.
2372 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFPass() {
2373   return std::make_unique<LegalizeTF>();
2374 }
2375 
2376 }  // namespace tosa
2377 
2378 }  // namespace mlir
2379