1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <utility>
17 
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Passes.h"
21 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
22 #include "mlir/Dialect/Tensor/IR/Tensor.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
26 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
27 
28 namespace tensorflow {
29 namespace {
30 
31 #define GEN_PASS_CLASSES
32 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
33 
34 using mlir::AffineMap;
35 using mlir::ConversionPatternRewriter;
36 using mlir::failure;
37 using mlir::LogicalResult;
38 using mlir::OpConversionPattern;
39 using mlir::OpRewritePattern;
40 using mlir::PatternRewriter;
41 using mlir::RankedTensorType;
42 using mlir::success;
43 using mlir::Type;
44 using mlir::TypeRange;
45 using mlir::Value;
46 using mlir::linalg::GenericOp;
47 using mlir::tensor::ExtractOp;
48 using mlir::tensor::FromElementsOp;
49 
IsNotZeroRankTensor(RankedTensorType tensor_type)50 bool IsNotZeroRankTensor(RankedTensorType tensor_type) {
51   return !tensor_type || tensor_type.getRank() > 0;
52 }
53 
54 /// A conversion patttern for detensoring Linalg ops.
55 struct DetensorizeLinalgOp : public OpConversionPattern<GenericOp> {
56   using OpConversionPattern<GenericOp>::OpConversionPattern;
57 
matchAndRewritetensorflow::__anona19f27960111::DetensorizeLinalgOp58   LogicalResult matchAndRewrite(
59       GenericOp op, OpAdaptor /*adaptor*/,
60       ConversionPatternRewriter& rewriter) const override {
61     mlir::Location loc = op.getLoc();
62     mlir::SmallVector<AffineMap, 3> indexing_maps = op.getIndexingMapsArray();
63 
64     mlir::SmallVector<Value, 3> inputs;
65     bool found_zero_dim_tensor = false;
66     for (auto& en : llvm::enumerate(op.getInputOperands())) {
67       auto tensor_type =
68           en.value()->get().getType().dyn_cast<RankedTensorType>();
69       if (IsNotZeroRankTensor(tensor_type)) {
70         inputs.push_back(en.value()->get());
71         continue;
72       }
73       found_zero_dim_tensor = true;
74       indexing_maps[en.index()] =
75           AffineMap::get(op.getNumLoops(), 0, llvm::None, op.getContext());
76       inputs.push_back(rewriter.create<ExtractOp>(loc, en.value()->get(),
77                                                   mlir::ValueRange{}));
78     }
79     if (!found_zero_dim_tensor) return failure();
80 
81     auto linalg_op = rewriter.create<GenericOp>(
82         loc, op.getResultTypes(), inputs, op.outputs(),
83         rewriter.getAffineMapArrayAttr(indexing_maps), op.iterator_types(),
84         mlir::StringAttr(), mlir::StringAttr());
85     mlir::Region& region = linalg_op.region();
86     rewriter.inlineRegionBefore(op.getBodyRegion(), region, region.end());
87     rewriter.replaceOp(op, linalg_op.getResults());
88     return success();
89   }
90 };
91 
92 struct DetensorizeLinalgPass
93     : public DetensorizeLinalgBase<DetensorizeLinalgPass> {
94   DetensorizeLinalgPass() = default;
95 
runOnOperationtensorflow::__anona19f27960111::DetensorizeLinalgPass96   void runOnOperation() override {
97     auto func = getOperation();
98     auto* context = &getContext();
99 
100     mlir::ConversionTarget target(*context);
101     target.markUnknownOpDynamicallyLegal([](mlir::Operation*) { return true; });
102     target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
103       return llvm::all_of(TypeRange{op.inputs()}, [&](Type type) {
104         return IsNotZeroRankTensor(type.dyn_cast<RankedTensorType>());
105       });
106     });
107 
108     // Detensorize.
109     mlir::RewritePatternSet patterns(context);
110     patterns.add<DetensorizeLinalgOp>(context);
111     if (failed(applyFullConversion(func, target, std::move(patterns))))
112       signalPassFailure();
113 
114     // Canonicalize.
115     mlir::RewritePatternSet canonicalization_patterns(context);
116     FromElementsOp::getCanonicalizationPatterns(patterns, context);
117     if (failed(applyPatternsAndFoldGreedily(
118             func, std::move(canonicalization_patterns))))
119       signalPassFailure();
120   }
121 };
122 
123 }  // namespace
124 
125 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDetensorizeLinalgPass()126 CreateDetensorizeLinalgPass() {
127   return std::make_unique<DetensorizeLinalgPass>();
128 }
129 
130 }  // namespace tensorflow
131