/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <iterator>
#include <memory>
#include <utility>

#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"

namespace tensorflow {
namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"

using llvm::SmallVector;
using mlir::Attribute;
using mlir::dyn_cast;
using mlir::failure;
using mlir::MLIRContext;
using mlir::Operation;
using mlir::PatternRewriter;
using mlir::success;
using mlir::Value;
using mlir::arith::ConstantIndexOp;
using mlir::gml_st::LoopOp;
using mlir::linalg::GenericOp;
using mlir::linalg::LinalgTilingOptions;

/// Returns true if the operation is a GenericOp implementing a transposition.
// TODO(diegocaballero): Move it to MLIR core?
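// For example, a 2-D transpose expressed as a linalg.generic looks roughly
// like this (illustrative):
//
//   %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
//                           affine_map<(d0, d1) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel"]}
//        ins(%input : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
//          ^bb0(%in: f32, %out: f32):
//            linalg.yield %in : f32
//        } -> tensor<?x?xf32>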
bool IsTransposeGenericOp(Operation *op) {
  // Check that op is a generic op and has at least 2 dimensions.
  auto generic_op = dyn_cast<GenericOp>(op);
  if (!generic_op) return false;
  if (generic_op.getNumLoops() < 2) return false;

  // Check whether the body has only one operation (yield op). Transpose ops
  // fused with any other operations are not supported for now.
  mlir::Block *body = generic_op.getBody();
  if (body->empty() || body->begin() != std::prev(body->end())) return false;
  auto yield_op = dyn_cast<mlir::linalg::YieldOp>(body->back());
  if (!yield_op || (yield_op.getNumOperands() != 1)) return false;

  // Check input and output.
  if ((generic_op.getNumInputs() != 1) || (generic_op.getNumOutputs() != 1))
    return false;

  // Check that the input is yielded.
  if (generic_op.getTiedBlockArgument(generic_op.getInputOperand(0)) !=
      yield_op.getOperand(0))
    return false;

  // Check parallel iterators.
  auto iterator_types = generic_op.iterator_types();
  if (std::any_of(
          iterator_types.begin(), iterator_types.end(),
          [](Attribute attr) { return !mlir::isParallelIterator(attr); }))
    return false;

  // Check that one indexing map is the identity and the other a permutation.
  auto indexing_maps = generic_op.getIndexingMapsArray();
  if (indexing_maps.size() != 2) return false;
  return (indexing_maps[0].isIdentity() && indexing_maps[1].isPermutation()) ||
         (indexing_maps[0].isPermutation() && indexing_maps[1].isIdentity());
}

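// Rewrites a transpose linalg.generic into a tiled gml_st.loop and marks the
// tiled op with a transformation attribute so that the pattern is not applied
// to it again.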
struct TileTransposePattern : public mlir::OpRewritePattern<GenericOp> {
  TileTransposePattern(LinalgTilingOptions options, MLIRContext *context,
                       mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<GenericOp>(context, benefit), options(options) {}

  mlir::LogicalResult matchAndRewrite(
      GenericOp linalg_op, PatternRewriter &rewriter) const override {
    if (hasTransformationAttr(linalg_op)) return failure();
    if (!IsTransposeGenericOp(linalg_op)) return failure();

    auto tiled_linalg_op =
        mlir::gml_st::tileLinalgOp(rewriter, linalg_op, options);
    if (failed(tiled_linalg_op) || tiled_linalg_op.getValue().loops.empty())
      return failure();

    auto tiled_loop =
        mlir::dyn_cast<LoopOp>(*tiled_linalg_op.getValue().loops.front());
    if (!tiled_loop) return failure();

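    // Mark the tiled ops inside the new loop so that this pattern does not
    // fire on them again; the markers are removed at the end of the pass.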
    tiled_loop->walk(
        [&](GenericOp tiledOp) { setTransformationAttr(rewriter, tiledOp); });

    rewriter.replaceOp(linalg_op, tiled_loop->getResults());
    return success();
  }

 private:
  LinalgTilingOptions options;
};

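// Pass that tiles every transpose linalg.generic in a function, with tile
// sizes chosen for a later 2-D vectorization of the transpose.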
struct TileTransposePass : public TileTransposeBase<TileTransposePass> {
  void runOnOperation() override {
    auto get_tile_size = [&](mlir::OpBuilder b, Operation *op) {
      auto generic_op = llvm::cast<GenericOp>(op);
      unsigned num_loops = generic_op.getNumLoops();
      assert(num_loops >= 2 && "Expect two or more dimensions in transpose op");

      // Compute the tile sizes for the 2-D vectorization of the transpose. We
      // pick eight as the default vectorization factor for both dimensions
      // since it's the most performant AVX2 pattern for now. We pick the
      // contiguous dimension of the input as the first vector dimension and
      // the contiguous dimension of the output as the second vector dimension.
      // This maximizes contiguous vector loads/stores and minimizes
      // insert/extract/gather/scatter operations.
      SmallVector<Value> tiles(num_loops,
                               b.create<ConstantIndexOp>(op->getLoc(), 1));
      auto indexing_maps = generic_op.getIndexingMapsArray();
      unsigned last_dim = num_loops - 1;
      unsigned vec_factor0 = 8, vec_factor1 = 8;
      unsigned vec_dim0 = indexing_maps[0].getDimPosition(last_dim);
      unsigned vec_dim1 = indexing_maps[1].getDimPosition(last_dim);

      // If the contiguous dimensions of the input and the output are not
      // transposed (i.e., they are the same), we vectorize only that
      // dimension. That case doesn't require intra-register transposition,
      // just copying a set of contiguous sub-buffers from the input to the
      // output tensor. Vectorizing a second dimension would only increase
      // memory pressure for no benefit.
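      // For instance (illustrative), a 2-D transpose with indexing maps
      // (d0, d1) -> (d0, d1) and (d0, d1) -> (d1, d0) gives vec_dim0 = 1 and
      // vec_dim1 = 0, so the tile sizes become [8, 8]. A 3-D op that only
      // swaps its two outer dimensions, with maps (d0, d1, d2) -> (d0, d1, d2)
      // and (d0, d1, d2) -> (d1, d0, d2), has vec_dim0 == vec_dim1 == 2 and
      // gets tile sizes [1, 1, 8].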
      if (vec_dim0 == vec_dim1) {
        tiles[vec_dim0] = b.create<ConstantIndexOp>(op->getLoc(), vec_factor0);
      } else {
        tiles[vec_dim0] = b.create<ConstantIndexOp>(op->getLoc(), vec_factor0);
        tiles[vec_dim1] = b.create<ConstantIndexOp>(op->getLoc(), vec_factor1);
      }

      return tiles;
    };

    auto func = getOperation();
    auto tiling_options =
        LinalgTilingOptions().setTileSizeComputationFunction(get_tile_size);

    mlir::RewritePatternSet patterns(func.getContext());
    patterns.add<TileTransposePattern>(tiling_options, patterns.getContext());
    if (failed(mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
      signalPassFailure();
    }

    // Ensure we drop the transformation markers at the end.
    func.walk([](GenericOp op) { removeTransformationAttr(op); });
  }
};

}  // namespace

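// Creates a pass that tiles transpose operations for vectorization.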
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateTileTransposePass() {
  return std::make_unique<TileTransposePass>();
}

}  // namespace tensorflow