1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <utility>
17
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Passes.h"
21 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
26 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
27
28 namespace tensorflow {
29 namespace {
30
31 #define GEN_PASS_CLASSES
32 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
33
34 using mlir::AffineMap;
35 using mlir::ConversionPatternRewriter;
36 using mlir::failure;
37 using mlir::LogicalResult;
38 using mlir::OpConversionPattern;
39 using mlir::OpRewritePattern;
40 using mlir::PatternRewriter;
41 using mlir::RankedTensorType;
42 using mlir::success;
43 using mlir::Type;
44 using mlir::TypeRange;
45 using mlir::Value;
46 using mlir::linalg::GenericOp;
47 using mlir::tensor::ExtractOp;
48 using mlir::tensor::FromElementsOp;
49
IsNotZeroRankTensor(RankedTensorType tensor_type)50 bool IsNotZeroRankTensor(RankedTensorType tensor_type) {
51 return !tensor_type || tensor_type.getRank() > 0;
52 }
53
54 /// A conversion patttern for detensoring Linalg ops.
55 struct DetensorizeLinalgOp : public OpConversionPattern<GenericOp> {
56 using OpConversionPattern<GenericOp>::OpConversionPattern;
57
matchAndRewritetensorflow::__anona19f27960111::DetensorizeLinalgOp58 LogicalResult matchAndRewrite(
59 GenericOp op, OpAdaptor /*adaptor*/,
60 ConversionPatternRewriter& rewriter) const override {
61 mlir::Location loc = op.getLoc();
62 mlir::SmallVector<AffineMap, 3> indexing_maps = op.getIndexingMapsArray();
63
64 mlir::SmallVector<Value, 3> inputs;
65 bool found_zero_dim_tensor = false;
66 for (auto& en : llvm::enumerate(op.getInputOperands())) {
67 auto tensor_type =
68 en.value()->get().getType().dyn_cast<RankedTensorType>();
69 if (IsNotZeroRankTensor(tensor_type)) {
70 inputs.push_back(en.value()->get());
71 continue;
72 }
73 found_zero_dim_tensor = true;
74 indexing_maps[en.index()] =
75 AffineMap::get(op.getNumLoops(), 0, llvm::None, op.getContext());
76 inputs.push_back(rewriter.create<ExtractOp>(loc, en.value()->get(),
77 mlir::ValueRange{}));
78 }
79 if (!found_zero_dim_tensor) return failure();
80
81 auto linalg_op = rewriter.create<GenericOp>(
82 loc, op.getResultTypes(), inputs, op.outputs(),
83 rewriter.getAffineMapArrayAttr(indexing_maps), op.iterator_types(),
84 mlir::StringAttr(), mlir::StringAttr());
85 mlir::Region& region = linalg_op.region();
86 rewriter.inlineRegionBefore(op.getBodyRegion(), region, region.end());
87 rewriter.replaceOp(op, linalg_op.getResults());
88 return success();
89 }
90 };
91
92 struct DetensorizeLinalgPass
93 : public DetensorizeLinalgBase<DetensorizeLinalgPass> {
94 DetensorizeLinalgPass() = default;
95
runOnOperationtensorflow::__anona19f27960111::DetensorizeLinalgPass96 void runOnOperation() override {
97 auto func = getOperation();
98 auto* context = &getContext();
99
100 mlir::ConversionTarget target(*context);
101 target.markUnknownOpDynamicallyLegal([](mlir::Operation*) { return true; });
102 target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
103 return llvm::all_of(TypeRange{op.inputs()}, [&](Type type) {
104 return IsNotZeroRankTensor(type.dyn_cast<RankedTensorType>());
105 });
106 });
107
108 // Detensorize.
109 mlir::RewritePatternSet patterns(context);
110 patterns.add<DetensorizeLinalgOp>(context);
111 if (failed(applyFullConversion(func, target, std::move(patterns))))
112 signalPassFailure();
113
114 // Canonicalize.
115 mlir::RewritePatternSet canonicalization_patterns(context);
116 FromElementsOp::getCanonicalizationPatterns(patterns, context);
117 if (failed(applyPatternsAndFoldGreedily(
118 func, std::move(canonicalization_patterns))))
119 signalPassFailure();
120 }
121 };
122
123 } // namespace
124
125 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDetensorizeLinalgPass()126 CreateDetensorizeLinalgPass() {
127 return std::make_unique<DetensorizeLinalgPass>();
128 }
129
130 } // namespace tensorflow
131