/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Legalize TensorFlow to TOSA

#include <climits>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
#include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h"
#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"

#define PASS_NAME "tosa-legalize-tf"
#define DEBUG_TYPE PASS_NAME

namespace mlir {
namespace tosa {
namespace {

// Performs lowering to TOSA dialect
class LegalizeTF : public TosaLegalizeTFPassBase<LegalizeTF> {
 public:
  explicit LegalizeTF() {}
  void runOnOperation() override;
};

// All the Pat<> lowering mappings.
#include "tensorflow/compiler/mlir/tosa/transforms/tf_legalize_patterns.inc"

#define DECL_CONVERT_OP(tf_op)                                                \
  struct ConvertTF##tf_op##Op : public RewritePattern {                       \
    explicit ConvertTF##tf_op##Op(MLIRContext* context)                       \
        : RewritePattern(TF::tf_op##Op::getOperationName(), 1, context) {}    \
    LogicalResult matchAndRewrite(Operation* op,                              \
                                  PatternRewriter& rewriter) const override;  \
  }
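
// For reference, DECL_CONVERT_OP(MatMul) expands to roughly the following
// pattern declaration (illustrative expansion of the macro above; the 1 is
// the pattern benefit):
//
//   struct ConvertTFMatMulOp : public RewritePattern {
//     explicit ConvertTFMatMulOp(MLIRContext* context)
//         : RewritePattern(TF::MatMulOp::getOperationName(), 1, context) {}
//     LogicalResult matchAndRewrite(Operation* op,
//                                   PatternRewriter& rewriter) const override;
//   };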

// All the explicitly implemented complex lowerings.
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(Relu);
DECL_CONVERT_OP(Relu6);
DECL_CONVERT_OP(Equal);
DECL_CONVERT_OP(NotEqual);
DECL_CONVERT_OP(Greater);
DECL_CONVERT_OP(GreaterEqual);
DECL_CONVERT_OP(Add);
DECL_CONVERT_OP(AddV2);
DECL_CONVERT_OP(AddN);
DECL_CONVERT_OP(Sub);
DECL_CONVERT_OP(Mul);
DECL_CONVERT_OP(Square);
DECL_CONVERT_OP(SquaredDifference);
DECL_CONVERT_OP(Round);
DECL_CONVERT_OP(FloorDiv);
DECL_CONVERT_OP(FloorMod);
DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Maximum);
DECL_CONVERT_OP(Minimum);
DECL_CONVERT_OP(RealDiv);
DECL_CONVERT_OP(ArgMax);
DECL_CONVERT_OP(AvgPool);
DECL_CONVERT_OP(MaxPool);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(Reshape);
DECL_CONVERT_OP(Rank);
DECL_CONVERT_OP(Shape);
DECL_CONVERT_OP(ExpandDims);
DECL_CONVERT_OP(Squeeze);
DECL_CONVERT_OP(Fill);
DECL_CONVERT_OP(Conv2D);
DECL_CONVERT_OP(DepthwiseConv2dNative);
DECL_CONVERT_OP(Conv2DBackpropInput);
DECL_CONVERT_OP(Elu);
DECL_CONVERT_OP(Softmax);
DECL_CONVERT_OP(LogSoftmax);
DECL_CONVERT_OP(All);
DECL_CONVERT_OP(Any);
DECL_CONVERT_OP(Max);
DECL_CONVERT_OP(Min);
DECL_CONVERT_OP(Mean);
DECL_CONVERT_OP(Prod);
DECL_CONVERT_OP(Sum);
DECL_CONVERT_OP(FusedBatchNorm);
DECL_CONVERT_OP(FusedBatchNormV3);
DECL_CONVERT_OP(BiasAdd);
DECL_CONVERT_OP(Split);
DECL_CONVERT_OP(SplitV);
DECL_CONVERT_OP(Pack);
DECL_CONVERT_OP(Unpack);
DECL_CONVERT_OP(Transpose);
DECL_CONVERT_OP(Tile);
DECL_CONVERT_OP(Slice);
DECL_CONVERT_OP(StridedSlice);
DECL_CONVERT_OP(Less);
DECL_CONVERT_OP(LessEqual);
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
DECL_CONVERT_OP(Gather);
DECL_CONVERT_OP(GatherV2);
DECL_CONVERT_OP(GatherNd);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
DECL_CONVERT_OP(SpaceToBatchND);
DECL_CONVERT_OP(BatchToSpaceND);
DECL_CONVERT_OP(ZerosLike);
DECL_CONVERT_OP(Sigmoid);
DECL_CONVERT_OP(Tanh);
DECL_CONVERT_OP(LeakyRelu);
DECL_CONVERT_OP(Neg);
DECL_CONVERT_OP(StopGradient);
DECL_CONVERT_OP(ReverseV2);
DECL_CONVERT_OP(FakeQuantWithMinMaxArgs);
DECL_CONVERT_OP(FakeQuantWithMinMaxVars);
DECL_CONVERT_OP(LeftShift);
DECL_CONVERT_OP(RightShift);
DECL_CONVERT_OP(OneHot);
DECL_CONVERT_OP(BatchMatMulV2);
#undef DECL_CONVERT_OP

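// Lowers tf.Relu to tosa.clamp with a lower bound of zero and an effectively
// unbounded upper bound (INT32_MAX / FLT_MAX). Illustrative IR, with the
// tensor type chosen only for the example:
//
//   %0 = "tf.Relu"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
// becomes
//   %0 = "tosa.clamp"(%arg0) {min_int = 0, max_int = 2147483647,
//                             min_fp = 0.0, max_fp = 3.40282347e+38}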
LogicalResult ConvertTFReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_relu_op = cast<TF::ReluOp>(op);

  TensorType output_type =
      tf_relu_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::ClampOp>(
      rewriter, op, output_type, tf_relu_op.features(),
      rewriter.getI64IntegerAttr(0),
      rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
      rewriter.getF32FloatAttr(0.0f),
      rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
  return success();
}

LogicalResult ConvertTFRelu6Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_relu6_op = cast<TF::Relu6Op>(op);

  TensorType output_type =
      tf_relu6_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::ClampOp>(
      rewriter, op, output_type, tf_relu6_op.features(),
      rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(6),
      rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f));
  return success();
}

LogicalResult ConvertTFEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_equal_op = cast<TF::EqualOp>(op);

  TensorType output_type =
      tf_equal_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::EqualOp>(rewriter, op, output_type,
                                         tf_equal_op.x(), tf_equal_op.y());
  return success();
}

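// TOSA has no not_equal op; tf.NotEqual is lowered as
// logical_not(equal(x, y)).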
LogicalResult ConvertTFNotEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_not_equal_op = cast<TF::NotEqualOp>(op);

  TensorType output_type =
      tf_not_equal_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  auto op1_equal_in =
      CreateOpAndInfer<tosa::EqualOp>(rewriter, op->getLoc(), output_type,
                                      tf_not_equal_op.x(), tf_not_equal_op.y());

  auto op2_not_op1 = CreateOpAndInfer<tosa::LogicalNotOp>(
      rewriter, op->getLoc(), output_type, op1_equal_in.getResult());

  rewriter.replaceOp(op, {op2_not_op1.getResult()});

  return success();
}

LogicalResult ConvertTFGreaterOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_greater_op = cast<TF::GreaterOp>(op);

  TensorType output_type =
      tf_greater_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::GreaterOp>(
      rewriter, op, output_type, tf_greater_op.x(), tf_greater_op.y());
  return success();
}

LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_greater_equal_op = cast<TF::GreaterEqualOp>(op);

  TensorType output_type =
      tf_greater_equal_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::GreaterEqualOp>(rewriter, op, output_type,
                                                tf_greater_equal_op.x(),
                                                tf_greater_equal_op.y());
  return success();
}

LogicalResult ConvertTFAddOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_add_op = cast<TF::AddOp>(op);

  TensorType output_type =
      tf_add_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::AddOp>(rewriter, op, output_type, tf_add_op.x(),
                                       tf_add_op.y());
  return success();
}

LogicalResult ConvertTFAddV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_addv2_op = cast<TF::AddV2Op>(op);

  TensorType output_type =
      tf_addv2_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::AddOp>(rewriter, op, output_type,
                                       tf_addv2_op.x(), tf_addv2_op.y());
  return success();
}

// AddN is commutative and associative, so the variadic sum can be lowered as
// a chain of binary tosa.add ops.
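// e.g. AddN(%a, %b, %c) becomes add(%c, add(%a, %b)).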
LogicalResult ConvertTFAddNOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_addn_op = cast<TF::AddNOp>(op);

  TensorType output_type =
      tf_addn_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  SmallVector<Value> inputs(tf_addn_op.inputs());

  assert(inputs.size() >= 2);

  auto newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(),
                                             output_type, inputs[0], inputs[1]);
  for (int i = 2; i < inputs.size(); i++) {
    newOp = CreateOpAndInfer<tosa::AddOp>(rewriter, op->getLoc(), output_type,
                                          inputs[i], newOp.getResult());
  }

  rewriter.replaceOp(op, {newOp.getResult()});

  return success();
}

LogicalResult ConvertTFSubOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_sub_op = cast<TF::SubOp>(op);

  TensorType output_type =
      tf_sub_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::SubOp>(rewriter, op, output_type, tf_sub_op.x(),
                                       tf_sub_op.y());
  return success();
}

LogicalResult ConvertTFMulOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_mul_op = cast<TF::MulOp>(op);

  llvm::Optional<Value> result = convertMultiplyOp(
      rewriter, op, tf_mul_op.getResult(), tf_mul_op.x(), tf_mul_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFSquareOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_square_op = cast<TF::SquareOp>(op);

  llvm::Optional<Value> result =
      convertMultiplyOp(rewriter, op, tf_square_op.getResult(),
                        tf_square_op.x(), tf_square_op.x());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

LogicalResult ConvertTFSquaredDifferenceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_squared_op = cast<TF::SquaredDifferenceOp>(op);

  llvm::Optional<Value> result =
      convertSquaredDifferenceOp(rewriter, op, tf_squared_op.getResult(),
                                 tf_squared_op.x(), tf_squared_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});
  return success();
}

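// tf.Round only needs a real lowering for floating-point inputs; integer
// inputs are already exact, so the op reduces to a passthrough of its input.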
LogicalResult ConvertTFRoundOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_round_op = cast<TF::RoundOp>(op);

  TensorType input_type = tf_round_op.x().getType().dyn_cast<TensorType>();
  if (!input_type) {
    return op->emitOpError("Round: input not tensor type");
  }

  if (input_type.getElementType().isa<FloatType>()) {
    llvm::Optional<Value> result =
        convertRoundOp(rewriter, op, tf_round_op.getResult(), tf_round_op.x());

    if (!result) return failure();

    rewriter.replaceOp(op, {result.getValue()});
    return success();

  } else {
    tf_round_op.replaceAllUsesWith(tf_round_op.x());
    return success();
  }
}

LogicalResult ConvertTFFloorDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_floordiv_op = cast<TF::FloorDivOp>(op);

  llvm::Optional<Value> result =
      convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(),
                        tf_floordiv_op.x(), tf_floordiv_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFFloorModOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_floormod_op = cast<TF::FloorModOp>(op);

  llvm::Optional<Value> result =
      convertFloorModOp(rewriter, op, tf_floormod_op.getResult(),
                        tf_floormod_op.x(), tf_floormod_op.y());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

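// tf.Assert has no TOSA counterpart; it is simply erased.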
LogicalResult ConvertTFAssertOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  op->dropAllReferences();
  op->erase();
  return success();
}

LogicalResult ConvertTFMaximumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_maximum_op = cast<TF::MaximumOp>(op);

  TensorType output_type =
      tf_maximum_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::MaximumOp>(
      rewriter, op, output_type, tf_maximum_op.x(), tf_maximum_op.y());
  return success();
}

LogicalResult ConvertTFMinimumOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_minimum_op = cast<TF::MinimumOp>(op);

  TensorType output_type =
      tf_minimum_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::MinimumOp>(
      rewriter, op, output_type, tf_minimum_op.x(), tf_minimum_op.y());
  return success();
}

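// Integer tf.RealDiv maps directly onto tosa.div; floating-point division is
// decomposed into x * reciprocal(y).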
LogicalResult ConvertTFRealDivOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_div_op = cast<TF::RealDivOp>(op);

  TensorType y_type = tf_div_op.y().getType().dyn_cast<TensorType>();
  TensorType output_type =
      tf_div_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type || !y_type) return failure();

  Type element_type = output_type.getElementType();

  if (element_type.isa<IntegerType>()) {
    CreateReplaceOpAndInfer<tosa::DivOp>(rewriter, op, output_type,
                                         tf_div_op.x(), tf_div_op.y());
    return success();
  }

  auto reciprocal_op = CreateOpAndInfer<tosa::ReciprocalOp>(
      rewriter, op->getLoc(), tf_div_op.y().getType(), tf_div_op.y());

  auto mul_op = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), output_type, tf_div_op.x(),
      reciprocal_op.getResult(), 0);
  rewriter.replaceOp(op, {mul_op.getResult()});

  return success();
}

LogicalResult ConvertTFArgMaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_argmax_op = cast<TF::ArgMaxOp>(op);

  TensorType input_type =
      tf_argmax_op.input().getType().dyn_cast<TensorType>();
  TensorType output_type =
      tf_argmax_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type || !input_type) return failure();

  ElementsAttr axis_elems;
  if (!matchPattern(tf_argmax_op.dimension(), m_Constant(&axis_elems)))
    return failure();

  int32_t axis = axis_elems.getValues<IntegerAttr>()[0].getInt();
  if (axis < 0) {
    axis += input_type.getRank();
  }

  if (axis < 0 || axis >= input_type.getRank()) {
    return op->emitOpError("TFArgMax: invalid axis value");
  }

  IntegerAttr axis_attr = rewriter.getI64IntegerAttr(axis);

  CreateReplaceOpAndInfer<tosa::ArgMaxOp>(rewriter, op, output_type,
                                          tf_argmax_op.input(), axis_attr);

  return success();
}

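// Lowers tf.AvgPool (NHWC only) to tosa.avg_pool2d: the ksize/strides
// attributes become kernel/stride arrays, and the TOSA pad values are
// computed from the TF padding type.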
LogicalResult ConvertTFAvgPoolOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_avgpool_op = cast<TF::AvgPoolOp>(op);

  RankedTensorType input_type =
      tf_avgpool_op.value().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_avgpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  auto tmpAttr = tf_avgpool_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr kernel;
  {
    auto tmpAttr = tf_avgpool_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_avgpool_op.ksize();
    if (!tmpAttr) {
      kernel = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_avgpool_op.padding().str(), &tf_pad).ok())
      return failure();

    ArrayAttr dilation =
        rewriter.getI64ArrayAttr({1, 1});  // Pooling has no non-unit dilation

    SmallVector<int64_t, 4> i64array;

    for (auto& elem : tf_avgpool_op.ksize()) {
      int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
      i64array.emplace_back(value);
    }

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  CreateReplaceOpAndInfer<tosa::AvgPool2dOp>(
      rewriter, op, output_type, tf_avgpool_op.value(), kernel, stride, pad);
  return success();
}

LogicalResult ConvertTFMaxPoolOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_maxpool_op = cast<TF::MaxPoolOp>(op);

  RankedTensorType input_type =
      tf_maxpool_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_maxpool_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type || !output_type) return failure();

  auto tmpAttr = tf_maxpool_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr pad;
  ArrayAttr stride;
  ArrayAttr kernel;
  {
    auto tmpAttr = tf_maxpool_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_maxpool_op.ksize();
    if (!tmpAttr) {
      kernel = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t kernel_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t kernel_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_maxpool_op.padding().str(), &tf_pad).ok())
      return failure();

    // Pooling has no non-unit dilation
    ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1});

    SmallVector<int64_t, 4> i64array;

    for (auto& elem : tf_maxpool_op.ksize()) {
      int64_t value = elem.dyn_cast<IntegerAttr>().getInt();
      i64array.emplace_back(value);
    }

    RankedTensorType filter_type = RankedTensorType::get(
        llvm::makeArrayRef(i64array), rewriter.getIntegerType(64));

    if (!getPaddingValuesFromPadType(
            tf_pad,
            tensorflow::FORMAT_NHWC,  // TFLite only supports this
            1,                        // tensorflow::FORMAT_OHWI,
            input_type, filter_type, stride, dilation, rewriter, pad))
      return failure();
  }

  CreateReplaceOpAndInfer<tosa::MaxPool2dOp>(
      rewriter, op, output_type, tf_maxpool_op.input(), kernel, stride, pad);
  return success();
}

LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_concatv2_op = cast<TF::ConcatV2Op>(op);
  SmallVector<Value> values(tf_concatv2_op.values());

  ElementsAttr axis_elems;
  if (!matchPattern(tf_concatv2_op.axis(), m_Constant(&axis_elems)))
    return failure();

  int32_t axis = axis_elems.getValues<IntegerAttr>()[0].getInt();

  llvm::Optional<Value> result =
      convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFReshapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_reshape_op = cast<TF::ReshapeOp>(op);

  RankedTensorType output_type =
      tf_reshape_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Matching the shape operand as an elements attribute doesn't always work;
  // use output_type.getShape(), which is more stable.
  SmallVector<int64_t> shape_vals;
  for (int i = 0; i < output_type.getShape().size(); i++) {
    shape_vals.push_back(output_type.getShape()[i]);
  }
  ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals);

  CreateReplaceOpAndInfer<tosa::ReshapeOp>(rewriter, op, output_type,
                                           tf_reshape_op.tensor(), shape_attr);
  return success();
}

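// The rank of a ranked tensor is known at compile time, so tf.Rank folds to a
// single-element i32 tosa.const.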
LogicalResult ConvertTFRankOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_rank_op = cast<TF::RankOp>(op);

  RankedTensorType input_type =
      tf_rank_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  int32_t rank = input_type.getRank();

  RankedTensorType rank_type =
      RankedTensorType::get({1}, rewriter.getIntegerType(32));
  auto rank_attr = DenseElementsAttr::get(rank_type, {rank});
  auto rank_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                    rank_type, rank_attr);

  rewriter.replaceOp(op, {rank_const.getResult()});

  return success();
}

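// Likewise, tf.Shape of a ranked tensor folds to a constant i32 shape vector.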
LogicalResult ConvertTFShapeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_shape_op = cast<TF::ShapeOp>(op);

  RankedTensorType output_type =
      tf_shape_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  RankedTensorType input_type =
      tf_shape_op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();

  auto input_shape = input_type.getShape();

  SmallVector<int32_t> shape_arr;
  for (int i = 0; i < input_shape.size(); i++) {
    shape_arr.emplace_back(input_shape[i]);
  }

  RankedTensorType shape_type = RankedTensorType::get(
      {static_cast<int32_t>(shape_arr.size())}, rewriter.getIntegerType(32));
  auto shape_attr =
      DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr));
  auto shape_const = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                     shape_type, shape_attr);

  rewriter.replaceOp(op, {shape_const.getResult()});

  return success();
}

LogicalResult ConvertTFExpandDimsOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_expanddims_op = cast<TF::ExpandDimsOp>(op);

  llvm::Optional<Value> result =
      convertExpandDimsOp(rewriter, op, tf_expanddims_op.getResult(),
                          tf_expanddims_op.input(), tf_expanddims_op.dim());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSqueezeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_squeeze_op = cast<TF::SqueezeOp>(op);

  // Copy squeeze_dims into an int32_t array
  auto squeeze_dims_attr = tf_squeeze_op.squeeze_dimsAttr();
  SmallVector<int32_t> squeeze_dims;
  for (auto& squeeze_dim : squeeze_dims_attr) {
    squeeze_dims.emplace_back(squeeze_dim.dyn_cast<IntegerAttr>().getInt());
  }

  llvm::Optional<Value> result =
      convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(),
                       tf_squeeze_op.input(), squeeze_dims);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

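// tf.Fill with constant dims and a constant fill value is materialized
// directly as a dense tosa.const spanning the whole output shape.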
LogicalResult ConvertTFFillOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_fill_op = cast<TF::FillOp>(op);

  RankedTensorType output_type =
      tf_fill_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr dims_elems;
  if (!matchPattern(tf_fill_op.dims(), m_Constant(&dims_elems)))
    return failure();
  SmallVector<int64_t> dims_vals;
  uint32_t total_size = 1;
  for (int i = 0; i < dims_elems.getNumElements(); i++) {
    dims_vals.push_back(dims_elems.getValues<IntegerAttr>()[i].getInt());
    total_size *= dims_vals[i];
  }

  ElementsAttr value_elem;
  if (!matchPattern(tf_fill_op.value(), m_Constant(&value_elem)))
    return failure();

  RankedTensorType fill_type = RankedTensorType::get(
      ArrayRef<int64_t>(dims_vals), value_elem.getType().getElementType());
  DenseElementsAttr fill_attr;

  // Splat the fill value into a dense attribute of a compatible element type.
  if (value_elem.getType().getElementType().isa<FloatType>()) {
    SmallVector<float> fill_arr(
        total_size,
        value_elem.getValues<FloatAttr>()[0].getValue().convertToFloat());
    fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
  } else {
    SmallVector<int32_t> fill_arr(
        total_size,
        value_elem.getValues<IntegerAttr>()[0].getValue().getLimitedValue());
    fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr));
  }
  auto fill_const_op = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(),
                                                       fill_type, fill_attr);
  rewriter.replaceOp(op, {fill_const_op.getResult()});

  return success();
}

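// TOSA convolutions carry an explicit bias operand, which tf.Conv2D lacks;
// a zero bias of shape [output_channels] is synthesized before delegating to
// convertTFConv2DCommon.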
LogicalResult ConvertTFConv2DOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_conv2d_op = cast<TF::Conv2DOp>(op);

  RankedTensorType filter_type =
      tf_conv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_conv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not ranked tensors
  if (!filter_type || !output_type) return failure();

  // Set up a zero attr for subsequent pattern replacement if required
  auto bias_dim = filter_type.getShape().back();
  RankedTensorType bias_type =
      RankedTensorType::get({bias_dim}, filter_type.getElementType());
  auto bias_attr = rewriter.getZeroAttr(bias_type);
  auto bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
                                              bias_attr.cast<ElementsAttr>());

  llvm::Optional<Value> result = convertTFConv2DCommon(
      rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(),
      bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(),
      tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(),
      tf_conv2d_op.data_format());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_dwconv2d_op = cast<TF::DepthwiseConv2dNativeOp>(op);

  RankedTensorType input_type =
      tf_dwconv2d_op.input().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tf_dwconv2d_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_dwconv2d_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!output_type) return failure();

  if (!filter_type) {
    return op->emitOpError("DepthwiseConv2d: filter type unranked tensor");
  }

  auto tmpAttr = tf_dwconv2d_op.data_formatAttr();
  if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure();

  ArrayAttr stride;
  ArrayAttr dilation;
  ArrayAttr pad;
  {
    auto tmpAttr = tf_dwconv2d_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_dwconv2d_op.dilations();
    if (!tmpAttr) {
      dilation = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w});
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_dwconv2d_op.padding().str(), &tf_pad).ok())
      return failure();

    tensorflow::TensorFormat data_format_tf;
    if (!FormatFromString(tf_dwconv2d_op.data_format().str(), &data_format_tf))
      return failure();

    if (tf_pad == tensorflow::Padding::EXPLICIT) {
      pad = getPaddingValuesFromExplicitPadAttr(
          tf_dwconv2d_op.explicit_paddings(), data_format_tf, rewriter);
    } else {
      if (!getPaddingValuesFromPadType(tf_pad, data_format_tf,
                                       0,  // tensorflow::FORMAT_HWIO
                                       input_type, filter_type, stride,
                                       dilation, rewriter, pad))
        return failure();
    }
  }

  // Set up a zero bias attr for subsequent pattern replacement
  auto filter_shape = filter_type.getShape();
  auto bias_dim = filter_shape[2] * filter_shape[3];
  RankedTensorType bias_type =
      RankedTensorType::get({bias_dim}, filter_type.getElementType());
  auto bias_attr = rewriter.getZeroAttr(bias_type);
  auto bias = CreateOpAndInfer<tosa::ConstOp>(rewriter, op->getLoc(), bias_type,
                                              bias_attr.cast<ElementsAttr>());

  CreateReplaceOpAndInfer<tosa::DepthwiseConv2DOp>(
      rewriter, op, output_type, tf_dwconv2d_op.input(),
      tf_dwconv2d_op.filter(), bias, pad, stride, dilation);
  return success();
}

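// tf.Conv2DBackpropInput is a transposed convolution: lower it to
// tosa.transpose_conv2d with the filter transposed out of HWIO layout, unit
// dilation only, and a synthesized zero bias.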
LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_conv_op = cast<TF::Conv2DBackpropInputOp>(op);

  RankedTensorType input_type =
      tf_conv_op.out_backprop().getType().dyn_cast<RankedTensorType>();
  RankedTensorType filter_type =
      tf_conv_op.filter().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_conv_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!input_type) return failure();
  if (!filter_type) return failure();
  if (!output_type) return failure();

  // Transpose [H, W, I, O] to [O, H, W, I]
  auto filter_shape = filter_type.getShape();
  SmallVector<int64_t, 4> a1_transpose_dims;
  a1_transpose_dims.push_back(filter_shape[2]);
  a1_transpose_dims.push_back(filter_shape[0]);
  a1_transpose_dims.push_back(filter_shape[1]);
  a1_transpose_dims.push_back(filter_shape[3]);
  llvm::Optional<Value> a1_filter_transpose_perm = getConstTensor<int32_t>(
      rewriter, op, /*vec=*/{2, 0, 1, 3}, /*shape=*/{4});

  if (!a1_filter_transpose_perm) return failure();

  auto a1_filter_transpose_op = CreateOpAndInfer<tosa::TransposeOp>(
      rewriter, op->getLoc(),
      RankedTensorType::get(ArrayRef<int64_t>(a1_transpose_dims),
                            filter_type.getElementType()),
      tf_conv_op.filter(), a1_filter_transpose_perm.getValue());

  ArrayAttr stride;
  ArrayAttr outpad;
  ArrayAttr output_shape;
  {
    auto tmpAttr = tf_conv_op.strides();
    if (!tmpAttr) {
      stride = rewriter.getI64ArrayAttr({1, 1});
    } else {
      // Note: hardcoded to NHWC for now
      int64_t stride_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t stride_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      stride = rewriter.getI64ArrayAttr({stride_h, stride_w});
    }
  }
  {
    auto tmpAttr = tf_conv_op.dilations();
    if (tmpAttr) {
      // Note: hardcoded to NHWC for now
      int64_t dilation_h = tmpAttr[1].dyn_cast<IntegerAttr>().getInt();
      int64_t dilation_w = tmpAttr[2].dyn_cast<IntegerAttr>().getInt();
      // TOSA transpose_conv2d does not support non-unit dilation
      if (dilation_h != 1 || dilation_w != 1) return failure();
    }
  }
  {
    tensorflow::Padding tf_pad;
    if (!GetPaddingFromString(tf_conv_op.padding().str(), &tf_pad).ok())
      return failure();

    tensorflow::TensorFormat data_format_tf;
    if (!FormatFromString(tf_conv_op.data_format().str(), &data_format_tf))
      return failure();

    if (tf_pad == tensorflow::Padding::EXPLICIT) {
      outpad = getPaddingValuesFromExplicitPadAttr(
          tf_conv_op.explicit_paddings(), data_format_tf, rewriter);
    } else {
      if (!getTransposeConv2dPaddingValues(tf_pad, data_format_tf,
                                           0,  // tensorflow::FORMAT_HWIO,
                                           input_type, filter_type, output_type,
                                           stride, rewriter, outpad))
        return failure();
    }
  }
  {
    ElementsAttr output_shape_elems;
    // Match from input_sizes tensor first.
    if (matchPattern(tf_conv_op.input_sizes(),
                     m_Constant(&output_shape_elems))) {
      SmallVector<int64_t> shape_vec;
      for (int i = 0; i < output_shape_elems.getNumElements(); i++)
        shape_vec.push_back(
            output_shape_elems.getValues<IntegerAttr>()[i].getInt());
      output_shape = rewriter.getI64ArrayAttr(shape_vec);
    } else {
      // Use output tensor's shape otherwise.
      output_shape = rewriter.getI64ArrayAttr(output_type.getShape());
    }
  }

  int output_channel = output_type.getShape()[3];
  SmallVector<float> vec(output_channel, 0.0f);
  llvm::Optional<Value> zero_bias =
      getConstTensor<float>(rewriter, op, vec, {output_channel});

  if (!zero_bias) return failure();

  CreateReplaceOpAndInfer<tosa::TransposeConv2DOp>(
      rewriter, op, output_type, tf_conv_op.out_backprop(),
      a1_filter_transpose_op.getResult(), zero_bias.getValue(), outpad, stride,
      output_shape);

  return success();
}

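// The reduction lowerings below (All, Any, Max, Min, Mean, Prod, Sum) all
// follow the same recipe: require a ranked output, match the constant
// reduction_indices operand, and delegate to the matching convertReduce*Op
// helper.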
LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_all_op = cast<TF::AllOp>(op);

  RankedTensorType output_type =
      tf_all_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_all_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceAllOp(
      rewriter, op, output_type, tf_all_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_any_op = cast<TF::AnyOp>(op);

  RankedTensorType output_type =
      tf_any_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_any_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceAnyOp(
      rewriter, op, output_type, tf_any_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_max_op = cast<TF::MaxOp>(op);

  RankedTensorType output_type =
      tf_max_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_max_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceMaxOp(
      rewriter, op, output_type, tf_max_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_min_op = cast<TF::MinOp>(op);

  RankedTensorType output_type =
      tf_min_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_min_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceMinOp(
      rewriter, op, output_type, tf_min_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFMeanOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_mean_op = cast<TF::MeanOp>(op);

  RankedTensorType output_type =
      tf_mean_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_mean_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceMeanOp(
      rewriter, op, output_type, tf_mean_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFProdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_prod_op = cast<TF::ProdOp>(op);

  RankedTensorType output_type =
      tf_prod_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_prod_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceProdOp(
      rewriter, op, output_type, tf_prod_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_sum_op = cast<TF::SumOp>(op);

  RankedTensorType output_type =
      tf_sum_op.getResult().getType().dyn_cast<RankedTensorType>();
  if (!output_type) return failure();

  ElementsAttr axes_elems;
  if (!matchPattern(tf_sum_op.reduction_indices(), m_Constant(&axes_elems)))
    return failure();

  llvm::Optional<Value> result = convertReduceSumOp(
      rewriter, op, output_type, tf_sum_op.input(), axes_elems);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFEluOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_elu_op = cast<TF::EluOp>(op);

  llvm::Optional<Value> result =
      convertEluOp(rewriter, op, tf_elu_op.getResult(), tf_elu_op.features());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_softmax_op = cast<TF::SoftmaxOp>(op);

  llvm::Optional<Value> result =
      convertSoftmaxOp(rewriter, op, tf_softmax_op.getResult(),
                       tf_softmax_op.logits(), /*beta=*/1.0);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLogSoftmaxOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_logsoftmax_op = cast<TF::LogSoftmaxOp>(op);

  llvm::Optional<Value> result = convertLogSoftmaxOp(
      rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.logits());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_batchnorm_op = cast<TF::FusedBatchNormOp>(op);

  RankedTensorType output_type =
      tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Lowering:
  // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
  //
  // shape_0 = ones(input.rank)
  // shape_0[input.rank-1] = input.shape[input.rank-1]
  // shape_1 = ones(1)
  //
  // bmean = reshape(mean, shape_0)
  // bscale = reshape(scale, shape_0)
  // boffset = reshape(offset, shape_0)
  // beps = reshape(epsilon, shape_1)
  //
  // op1 = sub(input, bmean)
  // op2 = add(var, beps)
  // op3 = rsqrt(op2)
  // bvar = reshape(op3, shape_0)
  // op4 = mul(op1, bvar)
  // op5 = mul(op4, bscale)
  // op6 = add(op5, boffset)

  RankedTensorType mean_type =
      tf_batchnorm_op.mean().getType().dyn_cast<RankedTensorType>();
  RankedTensorType variance_type =
      tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
  if (!variance_type || !mean_type) return failure();

  Value mean_val, variance_val;

  if (mean_type.getNumElements() == 0) {
    mean_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 0);
  } else {
    mean_val = tf_batchnorm_op.mean();
  }

  if (variance_type.getNumElements() == 0) {
    variance_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 1.0);
  } else {
    variance_val = tf_batchnorm_op.variance();
  }

  RankedTensorType epsilon_type =
      RankedTensorType::get({1}, variance_type.getElementType());
  auto epsilon_attr =
      DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
  auto epsilon_const = CreateOpAndInfer<tosa::ConstOp>(
      rewriter, op->getLoc(), epsilon_type, epsilon_attr);

  auto op1_sub_input_mean = CreateOpAndInfer<tosa::SubOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      tf_batchnorm_op.x(), mean_val);

  auto op2_add_var_epsilon = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), variance_val.getType(), variance_val,
      epsilon_const.getResult());

  auto op3_rsqrt_op2 = CreateOpAndInfer<tosa::RsqrtOp>(
      rewriter, op->getLoc(), variance_val.getType(),
      op2_add_var_epsilon.getResult());

  auto op4_mul_op1_op3 = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);

  auto op5_mul_op4_scale = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);

  auto op6_add_op5_offset = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());

  rewriter.replaceOp(op, {op6_add_op5_offset.getResult()});
  return success();
}

LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_batchnorm_op = cast<TF::FusedBatchNormV3Op>(op);

  if (tf_batchnorm_op.is_training())
    return rewriter.notifyMatchFailure(
        op, "unable to lower when is_training is set");

  for (auto value : tf_batchnorm_op.getResults().drop_front(1)) {
    if (!value.use_empty()) {
      // Really we should compute this still and let it DCE but I can't find
      // the math.
      return rewriter.notifyMatchFailure(
          op, "lowering does not support aggregate statistics");
    }
  }

  RankedTensorType output_type =
      tf_batchnorm_op.getResult(0).getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  // Lowering:
  // fused batchnorm = (input - mean) * scale * rsqrt(var + epsilon) + offset
  // op1 = sub(input, mean)
  // op2 = add(var, epsilon)
  // op3 = rsqrt(op2)
  // op4 = mul(op1, op3)
  // op5 = mul(op4, scale)
  // op6 = add(op5, offset)

  auto op1_sub_input_mean = CreateOpAndInfer<tosa::SubOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      tf_batchnorm_op.x(), tf_batchnorm_op.mean());

  RankedTensorType variance_type =
      tf_batchnorm_op.variance().getType().dyn_cast<RankedTensorType>();
  if (!variance_type) return failure();

  auto epsilon_type =
      RankedTensorType::get({1}, variance_type.getElementType());
  auto epsilon_attr =
      DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()});
  auto epsilon_const = CreateOpAndInfer<tosa::ConstOp>(
      rewriter, op->getLoc(), epsilon_type, epsilon_attr);

  auto op2_add_var_epsilon = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.variance().getType(),
      tf_batchnorm_op.variance(), epsilon_const);

  auto op3_rsqrt_op2 = CreateOpAndInfer<tosa::RsqrtOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.variance().getType(),
      op2_add_var_epsilon.getResult());

  auto op4_mul_op1_op3 = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op1_sub_input_mean.getResult(), op3_rsqrt_op2.getResult(), 0);

  auto op5_mul_op4_scale = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0);

  auto op6_add_op5_offset = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(),
      op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset());

  llvm::SmallVector<Value> replacements = {
      op6_add_op5_offset.getResult(), tf_batchnorm_op.mean(),
      tf_batchnorm_op.variance(),
      // The last three are reserved spaces and have no purpose currently.
      tf_batchnorm_op.mean(), tf_batchnorm_op.variance(),
      tf_batchnorm_op.variance()};
  rewriter.replaceOp(op, replacements);
  return success();
}

LogicalResult ConvertTFBiasAddOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_biasadd_op = cast<TF::BiasAddOp>(op);

  RankedTensorType output_type =
      tf_biasadd_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  auto add_op = CreateOpAndInfer<tosa::AddOp>(
      rewriter, op->getLoc(), output_type, tf_biasadd_op.value(),
      tf_biasadd_op.bias());

  rewriter.replaceOp(op, {add_op.getResult()});
  return success();
}

LogicalResult ConvertTFSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_slice_op = cast<TF::SliceOp>(op);

  RankedTensorType output_type =
      tf_slice_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr begin_elems, size_elems;

  SmallVector<int64_t> begin_vals, size_vals;

  // Assuming begin is always compile-time constant
  if (!matchPattern(tf_slice_op.begin(), m_Constant(&begin_elems))) {
    return op->emitOpError("TF::Slice error: begin is not constant");
  }

  for (int i = 0; i < begin_elems.getNumElements(); i++)
    begin_vals.push_back(begin_elems.getValues<IntegerAttr>()[i].getInt());

  // Try to match size as compile-time constant first;
  // if this fails, use the output tensor shape instead.
  if (matchPattern(tf_slice_op.size(), m_Constant(&size_elems))) {
    for (int i = 0; i < size_elems.getNumElements(); i++)
      size_vals.push_back(size_elems.getValues<IntegerAttr>()[i].getInt());
  } else {
    size_vals.assign(output_type.getShape().begin(),
                     output_type.getShape().end());
  }

  ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals);
  ArrayAttr size = rewriter.getI64ArrayAttr(size_vals);

  CreateReplaceOpAndInfer<tosa::SliceOp>(rewriter, op, output_type,
                                         tf_slice_op.input(), begin, size);
  return success();
}

LogicalResult ConvertTFTileOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_tile_op = cast<TF::TileOp>(op);

  RankedTensorType output_type =
      tf_tile_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  ElementsAttr multiples_elems;
  if (!matchPattern(tf_tile_op.multiples(), m_Constant(&multiples_elems)))
    return failure();
  SmallVector<int64_t> multiples_vals;
  for (int i = 0; i < multiples_elems.getNumElements(); i++)
    multiples_vals.push_back(
        multiples_elems.getValues<IntegerAttr>()[i].getInt());

  ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals);

  CreateReplaceOpAndInfer<tosa::TileOp>(rewriter, op, output_type,
                                        tf_tile_op.input(), multiples_attr);

  return success();
}

LogicalResult ConvertTFTransposeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_transpose_op = cast<TF::TransposeOp>(op);

  TensorType output_type =
      tf_transpose_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) {
    return failure();
  }

  CreateReplaceOpAndInfer<tosa::TransposeOp>(
      rewriter, op, output_type, tf_transpose_op.x(), tf_transpose_op.perm());

  return success();
}

LogicalResult ConvertTFPackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_pack_op = cast<TF::PackOp>(op);

  SmallVector<Value> inputs(tf_pack_op.values());

  assert(inputs.size() >= 2);

  IntegerAttr axis_attr = tf_pack_op.axisAttr();
  if (!axis_attr) axis_attr = rewriter.getI64IntegerAttr(0);

  int32_t axis_i32 = axis_attr.getInt();

  llvm::Optional<Value> result =
      convertPackOp(rewriter, op, tf_pack_op.getResult(), inputs, axis_i32);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFUnpackOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_unpack_op = cast<TF::UnpackOp>(op);

  IntegerAttr axis_attr;
  {
    auto tmpAttr = tf_unpack_op.axisAttr();
    if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
    axis_attr = tmpAttr;
  }
  int32_t axis_i32 = axis_attr.getInt();

  llvm::Optional<SmallVector<Value>> results =
      convertUnpackOp(rewriter, op, tf_unpack_op.value(), axis_i32);

  if (!results) return failure();

  rewriter.replaceOp(op, results.getValue());

  return success();
}

// Splits the input into num_split parts along split_dim.
LogicalResult ConvertTFSplitOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_split_op = cast<TF::SplitOp>(op);

  // Get the number of splits
  int32_t num_split = -1;

  auto range = tf_split_op.getODSResults(0);
  num_split = std::distance(range.begin(), range.end());

  // Get the axis
  int32_t axis = 0;
  ElementsAttr axisAttrElems;
  if (matchPattern(tf_split_op.split_dim(), m_Constant(&axisAttrElems))) {
    axis = axisAttrElems.getValues<IntegerAttr>()[0].getInt();
  }

  llvm::Optional<SmallVector<Value>> results =
      convertSplitOp(rewriter, op, tf_split_op.getResult(0),
                     tf_split_op.value(), num_split, axis);

  if (!results) return failure();

  rewriter.replaceOp(op, results.getValue());

  return success();
}

1566 // TFSplitV op splits based on a vector of sizes
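// For example (illustrative shapes only): size_splits = [1, 2, 3] along
// axis 0 of a [6, 4] tensor yields tensors of shape [1, 4], [2, 4], [3, 4].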
LogicalResult ConvertTFSplitVOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_splitv_op = cast<TF::SplitVOp>(op);

  // Get the size_splits array
  SmallVector<int32_t> size_split;
  ElementsAttr size_split_elems;
  if (!matchPattern(tf_splitv_op.size_splits(),
                    m_Constant(&size_split_elems))) {
    return failure();
  }

  for (int i = 0; i < size_split_elems.getNumElements(); i++) {
    size_split.push_back(size_split_elems.getValues<IntegerAttr>()[i].getInt());
  }

  // Get the axis
  ElementsAttr axisAttrElems;
  if (!matchPattern(tf_splitv_op.split_dim(), m_Constant(&axisAttrElems))) {
    return op->emitOpError("Cannot read split_dim elems");
  }

  int32_t axis = axisAttrElems.getValues<IntegerAttr>()[0].getInt();

  llvm::Optional<SmallVector<Value>> results =
      convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0),
                      tf_splitv_op.value(), size_split, axis);

  if (!results) return failure();

  rewriter.replaceOp(op, results.getValue());

  return success();
}

LogicalResult ConvertTFLessOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_less_op = cast<TF::LessOp>(op);

  TensorType output_type =
      tf_less_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  // less(x, y) is lowered as not(greater_equal(x, y))
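  //
  // For example: less([1, 3], [2, 2])
  //   = logical_not(greater_equal([1, 3], [2, 2]))
  //   = logical_not([false, true]) = [true, false]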
  auto greater_equal_op = CreateOpAndInfer<tosa::GreaterEqualOp>(
      rewriter, op->getLoc(), output_type, tf_less_op.x(), tf_less_op.y());

  auto not_op = CreateOpAndInfer<tosa::LogicalNotOp>(
      rewriter, op->getLoc(), output_type, greater_equal_op.getResult());

  rewriter.replaceOp(op, {not_op.getResult()});
  return success();
}

LogicalResult ConvertTFLessEqualOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_less_equal_op = cast<TF::LessEqualOp>(op);

  TensorType output_type =
      tf_less_equal_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  // less_equal(x, y) is lowered as not(greater(x, y))
  auto greater_op = CreateOpAndInfer<tosa::GreaterOp>(
      rewriter, op->getLoc(), output_type, tf_less_equal_op.x(),
      tf_less_equal_op.y());
  auto not_op = CreateOpAndInfer<tosa::LogicalNotOp>(
      rewriter, op->getLoc(), output_type, greater_op.getResult());

  rewriter.replaceOp(op, {not_op.getResult()});
  return success();
}

LogicalResult ConvertTFPadOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_pad_op = cast<TF::PadOp>(op);

  TensorType output_type =
      tf_pad_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  auto pad_op =
      CreateOpAndInfer<tosa::PadOp>(rewriter, op->getLoc(), output_type,
                                    tf_pad_op.input(), tf_pad_op.paddings());

  rewriter.replaceOp(op, {pad_op.getResult()});
  return success();
}

LogicalResult ConvertTFResizeBilinearOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_resize_op = cast<TF::ResizeBilinearOp>(op);

  RankedTensorType output_type =
      tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  llvm::Optional<Value> result = convertResizeOp(
      rewriter, op, output_type, tf_resize_op.images(), StringRef("BILINEAR"),
      tf_resize_op.align_cornersAttr().getValue(),
      tf_resize_op.half_pixel_centersAttr().getValue());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFResizeNearestNeighborOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_resize_op = cast<TF::ResizeNearestNeighborOp>(op);

  RankedTensorType output_type =
      tf_resize_op.getResult().getType().dyn_cast<RankedTensorType>();
  // Not a ranked tensor output
  if (!output_type) return failure();

  llvm::Optional<Value> result =
      convertResizeOp(rewriter, op, output_type, tf_resize_op.images(),
                      StringRef("NEAREST_NEIGHBOR"),
                      tf_resize_op.align_cornersAttr().getValue(),
                      tf_resize_op.half_pixel_centersAttr().getValue());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFMatMulOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_matmul_op = cast<TF::MatMulOp>(op);

  RankedTensorType a_type =
      tf_matmul_op.a().getType().dyn_cast<RankedTensorType>();
  RankedTensorType b_type =
      tf_matmul_op.b().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();

  if (!(a_type && b_type && output_type)) {
    return op->emitOpError("MatMul: a/b/output not ranked tensors");
  }

  if (a_type.getRank() != b_type.getRank() ||
      a_type.getRank() != output_type.getRank()) {
    return op->emitOpError("MatMul: a/b/output rank must match");
  }

  // Can only handle rank-2 tensors for tf.MatMul.
  // Cases with rank > 2 tensors should be handled by tf.BatchMatMul or
  // tf.BatchMatMulV2.
  if (a_type.getRank() != 2) {
    return op->emitOpError("MatMul: a/b/output rank must be 2");
  }

  SmallVector<int64_t, 3> batch_a_shape(
      {1, a_type.getShape()[0], a_type.getShape()[1]});
  SmallVector<int64_t, 3> batch_b_shape(
      {1, b_type.getShape()[0], b_type.getShape()[1]});
  SmallVector<int64_t, 3> batch_output_shape(
      {1, output_type.getShape()[0], output_type.getShape()[1]});

  RankedTensorType batch_a_type =
      RankedTensorType::get(batch_a_shape, a_type.getElementType());
  RankedTensorType batch_b_type =
      RankedTensorType::get(batch_b_shape, b_type.getElementType());
  RankedTensorType batch_output_type =
      RankedTensorType::get(batch_output_shape, output_type.getElementType());

  // Need to reshape input and output since TOSA matmul only supports
  // [N, H, C] * [N, C, W] -> [N, H, W].
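  //
  // For example, a rank-2 tf.MatMul of [3, 4] * [4, 5] -> [3, 5] is lowered
  // by reshaping to [1, 3, 4] * [1, 4, 5], applying tosa.matmul to get
  // [1, 3, 5], then reshaping back to [3, 5].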
  auto op1_reshape_a = CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(), batch_a_type, tf_matmul_op.a(),
      rewriter.getI64ArrayAttr(batch_a_shape));

  auto op2_reshape_b = CreateOpAndInfer<tosa::ReshapeOp>(
      rewriter, op->getLoc(), batch_b_type, tf_matmul_op.b(),
      rewriter.getI64ArrayAttr(batch_b_shape));

  auto op3_matmul_op1_op2 = CreateOpAndInfer<tosa::MatMulOp>(
      rewriter, op->getLoc(), batch_output_type, op1_reshape_a.getResult(),
      op2_reshape_b.getResult());

  CreateReplaceOpAndInfer<tosa::ReshapeOp>(
      rewriter, op, output_type, op3_matmul_op1_op2.getResult(),
      rewriter.getI64ArrayAttr(output_type.getShape()));

  return success();
}

LogicalResult ConvertTFGatherOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_gather_op = cast<TF::GatherOp>(op);

  // tf.Gather is equivalent to tf.GatherV2 with batch_dims = 0, axis = 0
  int32_t batch_dims = 0;
  int32_t axis = 0;

  llvm::Optional<Value> result = convertGatherOp(
      rewriter, op, tf_gather_op.getResult(), tf_gather_op.params(),
      tf_gather_op.indices(), batch_dims, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_gather_op = cast<TF::GatherV2Op>(op);

  // Axis is a tensor. Pull out the one integer value.
  ElementsAttr axis_elem;
  if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem)))
    return failure();
  assert(axis_elem.getNumElements() == 1);

  int32_t axis = axis_elem.getValues<IntegerAttr>()[0].getInt();
  int32_t batch_dims = tf_gather_op.batch_dimsAttr().getInt();

  llvm::Optional<Value> result = convertGatherOp(
      rewriter, op, tf_gather_op.getResult(), tf_gather_op.params(),
      tf_gather_op.indices(), batch_dims, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFGatherNdOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_gathernd_op = cast<TF::GatherNdOp>(op);

  llvm::Optional<Value> result =
      convertGatherNdOp(rewriter, op, tf_gathernd_op.getResult(),
                        tf_gathernd_op.params(), tf_gathernd_op.indices());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_sel_op = cast<TF::SelectV2Op>(op);

  llvm::Optional<Value> result =
      convertSelectOp(rewriter, op, tf_sel_op.getResult(),
                      tf_sel_op.condition(), tf_sel_op.t(), tf_sel_op.e());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSpaceToDepthOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_s2d_op = cast<TF::SpaceToDepthOp>(op);

  llvm::Optional<Value> result = convertSpaceToDepthOp(
      rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.input(),
      tf_s2d_op.block_sizeAttr(), tf_s2d_op.data_formatAttr());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFDepthToSpaceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_d2s_op = cast<TF::DepthToSpaceOp>(op);

  llvm::Optional<Value> result = convertDepthToSpaceOp(
      rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.input(),
      tf_d2s_op.block_sizeAttr(), tf_d2s_op.data_formatAttr());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSpaceToBatchNDOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_s2b_op = cast<TF::SpaceToBatchNDOp>(op);

  llvm::Optional<Value> result = convertSpaceToBatchNDOp(
      rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.input(),
      tf_s2b_op.block_shape(), tf_s2b_op.paddings());
  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFBatchToSpaceNDOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_b2s_op = cast<TF::BatchToSpaceNDOp>(op);

  llvm::Optional<Value> result = convertBatchToSpaceNDOp(
      rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.input(),
      tf_b2s_op.block_shape(), tf_b2s_op.crops());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFStridedSliceOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_ss_op = cast<TF::StridedSliceOp>(op);

  llvm::Optional<Value> result = convertStridedSliceOp(
      rewriter, op, tf_ss_op.getResult(), tf_ss_op.input(), tf_ss_op.begin(),
      tf_ss_op.end(), tf_ss_op.strides(), tf_ss_op.begin_maskAttr().getInt(),
      tf_ss_op.end_maskAttr().getInt(), tf_ss_op.ellipsis_maskAttr().getInt(),
      tf_ss_op.new_axis_maskAttr().getInt(),
      tf_ss_op.shrink_axis_maskAttr().getInt());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFZerosLikeOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_zeroslike_op = cast<TF::ZerosLikeOp>(op);

  llvm::Optional<Value> result = convertZerosLikeOp(
      rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.x());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFSigmoidOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_sigmoid_op = cast<TF::SigmoidOp>(op);
  TensorType output_type =
      tf_sigmoid_op.getResult().getType().dyn_cast<TensorType>();
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::SigmoidOp>(rewriter, op, output_type,
                                           tf_sigmoid_op.x());

  return success();
}

LogicalResult ConvertTFTanhOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_tanh_op = cast<TF::TanhOp>(op);
  TensorType output_type =
      tf_tanh_op.getResult().getType().dyn_cast<TensorType>();
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::TanhOp>(rewriter, op, output_type,
                                        tf_tanh_op.x());

  return success();
}

LogicalResult ConvertTFLeakyReluOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_leakyrelu_op = cast<TF::LeakyReluOp>(op);
  TensorType output_type =
      tf_leakyrelu_op.getResult().getType().dyn_cast<TensorType>();
  if (!output_type) return failure();

  // Implement LeakyRelu as element-wise:
  //   out = x > 0 ? x : alpha * x
  //
  // In TOSA ops:
  //
  //   const_zero = constant(0)
  //   a1 = mul(x, alpha)
  //   a2 = greater_equal(x, const_zero)
  //   out = select(a2, x, a1)
  //
  // If alpha can be constrained to 0.0 <= alpha <= 1.0, then
  // an alternative simpler lowering could be implemented with:
  //
  //   max(mul(x, alpha), x)
  //
  // But this alternative is not robust unless alpha meets those constraints.
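  //
  // For example, with alpha = 0.2 and x = [-1.0, 2.0]:
  //   a1 = [-0.2, 0.4], a2 = [false, true], out = [-0.2, 2.0]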

  if (!output_type.getElementType().isF32()) {
    op->emitOpError("ConvertTFLeakyReluOp: only support F32");
    return failure();
  }

  FloatAttr tmpAttr = tf_leakyrelu_op.alphaAttr();
  // There is disagreement between the MLIR .td defaults and the TF
  // documentation on 0.2 vs 0.3, but 0.2 is used here.
  double alpha = 0.2;

  if (tmpAttr) {
    alpha = tmpAttr.getValueAsDouble();
  }

  Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);

  auto a1_mul = CreateOpAndInfer<tosa::MulOp>(
      rewriter, op->getLoc(), output_type, tf_leakyrelu_op.features(),
      getTosaConstTensorSingleF32(rewriter, op, alpha), 0);

  auto a2_ge = CreateOpAndInfer<tosa::GreaterEqualOp>(
      rewriter, op->getLoc(), UnrankedTensorType::get(rewriter.getI1Type()),
      tf_leakyrelu_op.features(), const_zero);

  auto a3_select = CreateOpAndInfer<tosa::SelectOp>(
      rewriter, op->getLoc(), output_type, a2_ge, tf_leakyrelu_op.features(),
      a1_mul.getResult());

  rewriter.replaceOp(op, {a3_select.getResult()});

  return success();
}

LogicalResult ConvertTFNegOp::matchAndRewrite(Operation* op,
                                              PatternRewriter& rewriter) const {
  auto tf_neg_op = cast<TF::NegOp>(op);
  TensorType output_type =
      tf_neg_op.getResult().getType().dyn_cast<TensorType>();
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::NegateOp>(rewriter, op, output_type,
                                          tf_neg_op.x());

  return success();
}

LogicalResult ConvertTFStopGradientOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_stopgrad_op = cast<TF::StopGradientOp>(op);
  TensorType output_type =
      tf_stopgrad_op.getResult().getType().dyn_cast<TensorType>();
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::IdentityOp>(rewriter, op, output_type,
                                            tf_stopgrad_op.input());

  return success();
}

LogicalResult ConvertTFReverseV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_reverse_op = cast<TF::ReverseV2Op>(op);
  RankedTensorType input_type =
      tf_reverse_op.tensor().getType().dyn_cast<RankedTensorType>();
  TensorType output_type =
      tf_reverse_op.getResult().getType().dyn_cast<TensorType>();
  if (!input_type || !output_type) return failure();

  ElementsAttr axis_elems;
  if (!matchPattern(tf_reverse_op.axis(), m_Constant(&axis_elems)))
    return failure();

  auto input_rank = input_type.getShape().size();
  Value val = tf_reverse_op.tensor();
  if (axis_elems.getNumElements() == 0) {
    // No axes to reverse: pass the input through unchanged.
    auto identity_op = CreateOpAndInfer<tosa::IdentityOp>(
        rewriter, op->getLoc(), output_type, val);
    val = identity_op.getResult();
  } else {
    // Emit one tosa.reverse per axis, chaining each result into the next.
    for (int i = 0; i < axis_elems.getNumElements(); i++) {
      int64_t axis_val = axis_elems.getValues<IntegerAttr>()[i].getInt();
      if (axis_val < 0) axis_val += input_rank;
      auto axis_attr = rewriter.getI64IntegerAttr(axis_val);
      auto reverse_op = CreateOpAndInfer<tosa::ReverseOp>(
          rewriter, op->getLoc(), output_type, val, axis_attr);

      val = reverse_op.getResult();
    }
  }

  rewriter.replaceOp(op, {val});

  return success();
}

LogicalResult ConvertTFFakeQuantWithMinMaxArgsOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxArgsOp>(op);

  TensorType output_type =
      tf_fakequant_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  llvm::Optional<Value> result =
      convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.inputs(),
                         tf_fakequant_op.minAttr().getValueAsDouble(),
                         tf_fakequant_op.maxAttr().getValueAsDouble(),
                         tf_fakequant_op.num_bitsAttr().getInt(),
                         tf_fakequant_op.narrow_rangeAttr().getValue());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_fakequant_op = cast<TF::FakeQuantWithMinMaxVarsOp>(op);

  TensorType output_type =
      tf_fakequant_op.getResult().getType().dyn_cast<TensorType>();
  // Not a tensor output
  if (!output_type) return failure();

  // Only support min/max that can be matched at compile time
  ElementsAttr min_elems, max_elems;
  if (!matchPattern(tf_fakequant_op.min(), m_Constant(&min_elems)))
    return failure();

  if (!matchPattern(tf_fakequant_op.max(), m_Constant(&max_elems)))
    return failure();

  // Both min and max must be scalars.
  if (min_elems.getNumElements() != 1 || max_elems.getNumElements() != 1)
    return failure();

  // min/max are float tensors, so read them as FloatAttr rather than
  // IntegerAttr.
  double min_val = min_elems.getValues<FloatAttr>()[0].getValueAsDouble();
  double max_val = max_elems.getValues<FloatAttr>()[0].getValueAsDouble();

  llvm::Optional<Value> result = convertFakeQuantOp(
      rewriter, op, output_type, tf_fakequant_op.inputs(), min_val, max_val,
      tf_fakequant_op.num_bitsAttr().getInt(),
      tf_fakequant_op.narrow_rangeAttr().getValue());

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFLeftShiftOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_left_shift_op = cast<TF::LeftShiftOp>(op);

  TensorType output_type =
      tf_left_shift_op.getResult().getType().dyn_cast<TensorType>();
  if (!output_type) return failure();

  CreateReplaceOpAndInfer<tosa::LogicalLeftShiftOp>(
      rewriter, op, output_type, tf_left_shift_op.x(), tf_left_shift_op.y());

  return success();
}

LogicalResult ConvertTFRightShiftOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  // Performs a logical shift for unsigned integer types, and an arithmetic
  // shift for signed integer types.
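  //
  // For example, shifting -8 right by 1 gives -4 with an arithmetic shift
  // (sign bit preserved), but a large positive value with a logical shift.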
  auto tf_right_shift_op = cast<TF::RightShiftOp>(op);

  TensorType output_type =
      tf_right_shift_op.getResult().getType().dyn_cast<TensorType>();
  if (!output_type) return failure();

  Type output_element_type = output_type.getElementType();

  bool is_signed = !output_element_type.isUnsignedInteger();

  if (is_signed) {
    CreateReplaceOpAndInfer<tosa::ArithmeticRightShiftOp>(
        rewriter, op, output_type, tf_right_shift_op.x(), tf_right_shift_op.y(),
        false);
  } else {
    CreateReplaceOpAndInfer<tosa::LogicalRightShiftOp>(
        rewriter, op, output_type, tf_right_shift_op.x(),
        tf_right_shift_op.y());
  }

  return success();
}

LogicalResult ConvertTFOneHotOp::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_one_hot_op = cast<TF::OneHotOp>(op);

  ElementsAttr depth_elems;
  if (!matchPattern(tf_one_hot_op.depth(), m_Constant(&depth_elems)))
    return failure();
  int32_t depth = depth_elems.getValues<IntegerAttr>()[0].getInt();

  IntegerAttr axisAttr = tf_one_hot_op.axisAttr();
  int32_t axis = axisAttr.getInt();

  llvm::Optional<Value> result = convertOneHotOp(
      rewriter, op, tf_one_hot_op.getResult(), tf_one_hot_op.indices(),
      tf_one_hot_op.on_value(), tf_one_hot_op.off_value(), depth, axis);

  if (!result) return failure();

  rewriter.replaceOp(op, {result.getValue()});

  return success();
}

LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite(
    Operation* op, PatternRewriter& rewriter) const {
  auto tf_batch_matmul_op = cast<TF::BatchMatMulV2Op>(op);

  RankedTensorType x_type =
      tf_batch_matmul_op.x().getType().dyn_cast<RankedTensorType>();
  RankedTensorType y_type =
      tf_batch_matmul_op.y().getType().dyn_cast<RankedTensorType>();
  RankedTensorType output_type =
      tf_batch_matmul_op.getResult().getType().dyn_cast<RankedTensorType>();

  if (!(x_type && y_type && output_type)) {
    return op->emitOpError("BatchMatMulV2: x/y/output not ranked tensors");
  }

  if (x_type.getRank() != y_type.getRank() ||
      x_type.getRank() != output_type.getRank()) {
    return op->emitOpError("BatchMatMulV2: x/y/output rank must match");
  }

  if (x_type.getRank() <= 2) {
    return op->emitOpError("BatchMatMulV2: x/y/output rank must be > 2");
  }

  // Rank-3 batch matmul maps directly onto tosa.matmul.
  if (x_type.getRank() == 3) {
    CreateReplaceOpAndInfer<tosa::MatMulOp>(rewriter, op, output_type,
                                            tf_batch_matmul_op.x(),
                                            tf_batch_matmul_op.y());
  } else {
    // 1. Reshape x from: (similar for y)
    //    [a0, a1, ... an, H, C] to [N, H, C],
    //    where N = a0 * a1 * ... * an.
    // 2. tosa.MatMul:
    //    [N, H, C] * [N, C, W] -> [N, H, W].
    // 3. Reshape output from:
    //    [N, H, W] to [a0, a1, ... , an, H, W].
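    //
    // For example, x: [2, 3, 4, 5] and y: [2, 3, 5, 6] are reshaped to
    // [6, 4, 5] and [6, 5, 6]; tosa.matmul yields [6, 4, 6], which is
    // reshaped back to the output shape [2, 3, 4, 6].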
    int64_t rank = x_type.getRank();
    int64_t N = 1;
    for (int i = 0; i < (rank - 2); i++) {
      N *= x_type.getShape()[i];
    }
    int64_t H = x_type.getShape()[rank - 2];
    int64_t C = x_type.getShape()[rank - 1];
    int64_t W = y_type.getShape()[rank - 1];

    SmallVector<int64_t, 3> rank3_x_shape({N, H, C});
    SmallVector<int64_t, 3> rank3_y_shape({N, C, W});
    SmallVector<int64_t, 3> rank3_output_shape({N, H, W});

    RankedTensorType rank3_x_type =
        RankedTensorType::get(rank3_x_shape, x_type.getElementType());
    RankedTensorType rank3_y_type =
        RankedTensorType::get(rank3_y_shape, y_type.getElementType());
    RankedTensorType rank3_output_type =
        RankedTensorType::get(rank3_output_shape, output_type.getElementType());

    auto op1_reshape_x = CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(), rank3_x_type, tf_batch_matmul_op.x(),
        rewriter.getI64ArrayAttr(rank3_x_shape));

    auto op2_reshape_y = CreateOpAndInfer<tosa::ReshapeOp>(
        rewriter, op->getLoc(), rank3_y_type, tf_batch_matmul_op.y(),
        rewriter.getI64ArrayAttr(rank3_y_shape));

    auto op3_matmul_op1_op2 = CreateOpAndInfer<tosa::MatMulOp>(
        rewriter, op->getLoc(), rank3_output_type, op1_reshape_x.getResult(),
        op2_reshape_y.getResult());

    CreateReplaceOpAndInfer<tosa::ReshapeOp>(
        rewriter, op, output_type, op3_matmul_op1_op2.getResult(),
        rewriter.getI64ArrayAttr(output_type.getShape()));
  }
  return success();
}

void LegalizeTF::runOnOperation() {
  auto* ctx = &getContext();
  RewritePatternSet patterns(ctx);
  auto func = getOperation();
  populateLegalizeTFPatterns(ctx, patterns);

  if (ApplyPatternsWithShapeResolution(func, std::move(patterns)).failed()) {
    signalPassFailure();
  }
}

}  // anonymous namespace

void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) {
  // Add the generated patterns to the list.
  populateWithGenerated(patterns);
  patterns.add<ConvertTFMatMulOp>(ctx);
  patterns.add<ConvertTFReluOp>(ctx);
  patterns.add<ConvertTFRelu6Op>(ctx);
  patterns.add<ConvertTFEqualOp>(ctx);
  patterns.add<ConvertTFNotEqualOp>(ctx);
  patterns.add<ConvertTFGreaterOp>(ctx);
  patterns.add<ConvertTFGreaterEqualOp>(ctx);
  patterns.add<ConvertTFAddOp>(ctx);
  patterns.add<ConvertTFAddV2Op>(ctx);
  patterns.add<ConvertTFAddNOp>(ctx);
  patterns.add<ConvertTFSubOp>(ctx);
  patterns.add<ConvertTFMulOp>(ctx);
  patterns.add<ConvertTFSquareOp>(ctx);
  patterns.add<ConvertTFSquaredDifferenceOp>(ctx);
  patterns.add<ConvertTFRoundOp>(ctx);
  patterns.add<ConvertTFFloorDivOp>(ctx);
  patterns.add<ConvertTFFloorModOp>(ctx);
  patterns.add<ConvertTFAssertOp>(ctx);
  patterns.add<ConvertTFMaximumOp>(ctx);
  patterns.add<ConvertTFMinimumOp>(ctx);
  patterns.add<ConvertTFRealDivOp>(ctx);
  patterns.add<ConvertTFArgMaxOp>(ctx);
  patterns.add<ConvertTFAvgPoolOp>(ctx);
  patterns.add<ConvertTFMaxPoolOp>(ctx);
  patterns.add<ConvertTFConcatV2Op>(ctx);
  patterns.add<ConvertTFReshapeOp>(ctx);
  patterns.add<ConvertTFRankOp>(ctx);
  patterns.add<ConvertTFShapeOp>(ctx);
  patterns.add<ConvertTFExpandDimsOp>(ctx);
  patterns.add<ConvertTFSqueezeOp>(ctx);
  patterns.add<ConvertTFFillOp>(ctx);
  patterns.add<ConvertTFConv2DOp>(ctx);
  patterns.add<ConvertTFDepthwiseConv2dNativeOp>(ctx);
  patterns.add<ConvertTFConv2DBackpropInputOp>(ctx);
  patterns.add<ConvertTFEluOp>(ctx);
  patterns.add<ConvertTFSoftmaxOp>(ctx);
  patterns.add<ConvertTFLogSoftmaxOp>(ctx);
  patterns.add<ConvertTFAllOp>(ctx);
  patterns.add<ConvertTFAnyOp>(ctx);
  patterns.add<ConvertTFMaxOp>(ctx);
  patterns.add<ConvertTFMinOp>(ctx);
  patterns.add<ConvertTFMeanOp>(ctx);
  patterns.add<ConvertTFProdOp>(ctx);
  patterns.add<ConvertTFSumOp>(ctx);
  patterns.add<ConvertTFFusedBatchNormOp>(ctx);
  patterns.add<ConvertTFFusedBatchNormV3Op>(ctx);
  patterns.add<ConvertTFBiasAddOp>(ctx);
  patterns.add<ConvertTFSplitOp>(ctx);
  patterns.add<ConvertTFSplitVOp>(ctx);
  patterns.add<ConvertTFPackOp>(ctx);
  patterns.add<ConvertTFUnpackOp>(ctx);
  patterns.add<ConvertTFTransposeOp>(ctx);
  patterns.add<ConvertTFTileOp>(ctx);
  patterns.add<ConvertTFSliceOp>(ctx);
  patterns.add<ConvertTFStridedSliceOp>(ctx);
  patterns.add<ConvertTFLessOp>(ctx);
  patterns.add<ConvertTFLessEqualOp>(ctx);
  patterns.add<ConvertTFPadOp>(ctx);
  patterns.add<ConvertTFResizeBilinearOp>(ctx);
  patterns.add<ConvertTFResizeNearestNeighborOp>(ctx);
  patterns.add<ConvertTFGatherOp>(ctx);
  patterns.add<ConvertTFGatherV2Op>(ctx);
  patterns.add<ConvertTFGatherNdOp>(ctx);
  patterns.add<ConvertTFSelectV2Op>(ctx);
  patterns.add<ConvertTFSpaceToDepthOp>(ctx);
  patterns.add<ConvertTFDepthToSpaceOp>(ctx);
  patterns.add<ConvertTFSpaceToBatchNDOp>(ctx);
  patterns.add<ConvertTFBatchToSpaceNDOp>(ctx);
  patterns.add<ConvertTFZerosLikeOp>(ctx);
  patterns.add<ConvertTFSigmoidOp>(ctx);
  patterns.add<ConvertTFTanhOp>(ctx);
  patterns.add<ConvertTFLeakyReluOp>(ctx);
  patterns.add<ConvertTFNegOp>(ctx);
  patterns.add<ConvertTFStopGradientOp>(ctx);
  patterns.add<ConvertTFReverseV2Op>(ctx);
  patterns.add<ConvertTFFakeQuantWithMinMaxArgsOp>(ctx);
  patterns.add<ConvertTFFakeQuantWithMinMaxVarsOp>(ctx);
  patterns.add<ConvertTFLeftShiftOp>(ctx);
  patterns.add<ConvertTFRightShiftOp>(ctx);
  patterns.add<ConvertTFOneHotOp>(ctx);
  patterns.add<ConvertTFBatchMatMulV2Op>(ctx);
}

// Creates an instance of the TensorFlow dialect LegalizeTF pass.
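// A typical use (illustrative only, not the only way to schedule it) is to
// run the pass on each function via a pass manager:
//
//   mlir::PassManager pm(ctx);
//   pm.addNestedPass<mlir::func::FuncOp>(mlir::tosa::createLegalizeTFPass());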
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeTFPass() {
  return std::make_unique<LegalizeTF>();
}

}  // namespace tosa

}  // namespace mlir