/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <utility>
20
21 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
22 #include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
26 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
27 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
29
30 namespace tensorflow {
31 namespace {
32
33 #define GEN_PASS_CLASSES
34 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
35
36 using llvm::SmallVector;
37 using mlir::Attribute;
38 using mlir::dyn_cast;
39 using mlir::failure;
40 using mlir::MLIRContext;
41 using mlir::Operation;
42 using mlir::PatternRewriter;
43 using mlir::success;
44 using mlir::Value;
45 using mlir::arith::ConstantIndexOp;
46 using mlir::gml_st::LoopOp;
47 using mlir::linalg::GenericOp;
48 using mlir::linalg::LinalgTilingOptions;
49
50 /// Returns true if the operation is a GenericOp implementing a transposition.
51 // TODO(diegocaballero): Move it to MLIR core?
IsTransposeGenericOp(Operation * op)52 bool IsTransposeGenericOp(Operation *op) {
53 // Check that op is a generic op and has at least 2 dimensions.
54 auto generic_op = dyn_cast<GenericOp>(op);
55 if (!generic_op) return false;
56 if (generic_op.getNumLoops() < 2) return false;
57
58 // Check whether the body has only one operation (yield op). Transpose ops
59 // fused with any other operations are not supported for now.
60 mlir::Block *body = generic_op.getBody();
61 if (body->empty() || body->begin() != std::prev(body->end())) return false;
62 auto yield_op = dyn_cast<mlir::linalg::YieldOp>(body->back());
63 if (!yield_op || (yield_op.getNumOperands() != 1)) return false;
64
65 // Check input and output.
66 if ((generic_op.getNumInputs() != 1) || (generic_op.getNumOutputs() != 1))
67 return false;
68
69 // Check that input is yielded.
70 if (generic_op.getTiedBlockArgument(generic_op.getInputOperand(0)) !=
71 yield_op.getOperand(0))
72 return false;
73
74 // Check parallel iterators.
75 auto iterator_types = generic_op.iterator_types();
76 if (std::any_of(
77 iterator_types.begin(), iterator_types.end(),
78 [](Attribute attr) { return !mlir::isParallelIterator(attr); }))
79 return false;
80
81 // Check that the two indexing maps are a permutation.
82 auto indexing_maps = generic_op.getIndexingMapsArray();
83 if (indexing_maps.size() != 2) return false;
84 return (indexing_maps[0].isIdentity() && indexing_maps[1].isPermutation()) ||
85 (indexing_maps[0].isPermutation() && indexing_maps[1].isIdentity());
86 }
87
88 struct TileTransposePattern : public mlir::OpRewritePattern<GenericOp> {
TileTransposePatterntensorflow::__anon7bbabe600111::TileTransposePattern89 TileTransposePattern(LinalgTilingOptions options, MLIRContext *context,
90 mlir::PatternBenefit benefit = 1)
91 : mlir::OpRewritePattern<GenericOp>(context, benefit), options(options) {}
92
matchAndRewritetensorflow::__anon7bbabe600111::TileTransposePattern93 mlir::LogicalResult matchAndRewrite(
94 GenericOp linalg_op, PatternRewriter &rewriter) const override {
95 if (hasTransformationAttr(linalg_op)) return failure();
96 if (!IsTransposeGenericOp(linalg_op)) return failure();
97
98 auto tiled_linalg_op =
99 mlir::gml_st::tileLinalgOp(rewriter, linalg_op, options);
100 if (failed(tiled_linalg_op) || tiled_linalg_op.getValue().loops.empty())
101 return failure();
102
103 auto tiled_loop =
104 mlir::dyn_cast<LoopOp>(*tiled_linalg_op.getValue().loops.front());
105 if (!tiled_loop) return failure();
106
107 tiled_loop->walk(
108 [&](GenericOp tiledOp) { setTransformationAttr(rewriter, tiledOp); });
109
110 rewriter.replaceOp(linalg_op, tiled_loop->getResults());
111 return success();
112 }
113
114 private:
115 LinalgTilingOptions options;
116 };
117
118 struct TileTransposePass : public TileTransposeBase<TileTransposePass> {
runOnOperationtensorflow::__anon7bbabe600111::TileTransposePass119 void runOnOperation() override {
120 auto get_tile_size = [&](mlir::OpBuilder b, Operation *op) {
121 auto generic_op = llvm::cast<GenericOp>(op);
122 unsigned num_loops = generic_op.getNumLoops();
123 assert(num_loops >= 2 && "Expect two or more dimension in transpose op");
124
125 // Compute the tile sizes for the 2-D vectorization of the transpose. We
126 // pick eight as default vectorization factor for both dimensions since
127 // it's the most performant AVX2 pattern for now. We pick the contiguous
128 // dimension of the input as first vector dimension and the contiguous
129 // dimension of the output as second vector dimension. This will maximize
130 // contiguous vector loads/stores and minimize insert/extract/gather/
131 // scatter operations.
132 SmallVector<Value> tiles(num_loops,
133 b.create<ConstantIndexOp>(op->getLoc(), 1));
134 auto indexing_maps = generic_op.getIndexingMapsArray();
135 unsigned last_dim = num_loops - 1;
136 unsigned vec_factor0 = 8, vec_factor1 = 8;
137 unsigned vec_dim0 = indexing_maps[0].getDimPosition(last_dim);
138 unsigned vec_dim1 = indexing_maps[1].getDimPosition(last_dim);
139
140 // If the contiguous dimensions of both input and output are not
141 // transposed (i.e, they are the same), we vectorize only that dimension.
142 // That transpose case doesn't require intra-register transposition but
143 // just copying a set of contiguous sub-buffers from the input to the
144 // output tensor. Vectorizing a second dimension would increase too much
145 // the memory pressure for no reason.
146 if (vec_dim0 == vec_dim1) {
147 tiles[vec_dim0] = b.create<ConstantIndexOp>(op->getLoc(), vec_factor0);
148 } else {
149 tiles[vec_dim0] = b.create<ConstantIndexOp>(op->getLoc(), vec_factor0);
150 tiles[vec_dim1] = b.create<ConstantIndexOp>(op->getLoc(), vec_factor1);
151 }
152
153 return tiles;
154 };
155
156 auto func = getOperation();
157 auto tiling_options =
158 LinalgTilingOptions().setTileSizeComputationFunction(get_tile_size);
159
160 mlir::RewritePatternSet patterns(func.getContext());
161 patterns.add<TileTransposePattern>(tiling_options, patterns.getContext());
162 if (failed(mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
163 signalPassFailure();
164 }
165
166 // Ensure we drop the marker in the end.
167 func.walk([](GenericOp op) { removeTransformationAttr(op); });
168 }
169 };
170
171 } // namespace
172
173 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateTileTransposePass()174 CreateTileTransposePass() {
175 return std::make_unique<TileTransposePass>();
176 }
177
178 } // namespace tensorflow
179