xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_reduction.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <utility>
18 #include <vector>
19 
20 #include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
21 #include "mlir/Dialect/Affine/IR/AffineOps.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/Linalg/Passes.h"
25 #include "mlir/Dialect/Tensor/IR/Tensor.h"
26 #include "mlir/Dialect/Tensor/Utils/Utils.h"
27 #include "mlir/Dialect/Utils/StaticValueUtils.h"
28 #include "mlir/IR/BlockAndValueMapping.h"
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
31 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
32 
33 namespace tensorflow {
34 namespace {
35 
36 #define GEN_PASS_CLASSES
37 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
38 
39 using llvm::makeArrayRef;
40 using mlir::BlockAndValueMapping;
41 using mlir::dyn_cast;
42 using mlir::failure;
43 using mlir::FailureOr;
44 using mlir::Location;
45 using mlir::LogicalResult;
46 using mlir::MLIRContext;
47 using mlir::OpBuilder;
48 using mlir::Operation;
49 using mlir::OpRewritePattern;
50 using mlir::PatternRewriter;
51 using mlir::RankedTensorType;
52 using mlir::ShapedType;
53 using mlir::SmallVector;
54 using mlir::success;
55 using mlir::Value;
56 using mlir::ValueRange;
57 using mlir::arith::ConstantIndexOp;
58 using mlir::gml_st::LoopOp;
59 using mlir::linalg::FillOp;
60 using mlir::linalg::GenericOp;
61 using mlir::linalg::InitTensorOp;
62 using mlir::linalg::LinalgOp;
63 using mlir::linalg::LinalgTilingOptions;
64 using mlir::tensor::ExpandShapeOp;
65 using mlir::tensor::ExtractSliceOp;
66 
67 // Match 1D or 2D reduction.
isCanonicalizedReduction(Operation * op)68 bool isCanonicalizedReduction(Operation *op) {
69   auto reduction = mlir::dyn_cast<GenericOp>(op);
70   if (!reduction) return false;
71 
72   if (reduction.getNumOutputs() != 1) return false;
73   if (reduction.getNumLoops() > 2) return false;
74   return reduction.getNumReductionLoops() == 1;
75 }
76 
77 // Tiles a GenericOp that models a 2D row or column reduction.
78 struct RowOrColumnReductionTilingPattern : public OpRewritePattern<GenericOp> {
RowOrColumnReductionTilingPatterntensorflow::__anonf12ed84e0111::RowOrColumnReductionTilingPattern79   RowOrColumnReductionTilingPattern(const LinalgTilingOptions &options,
80                                     MLIRContext *context,
81                                     mlir::PatternBenefit benefit = 1)
82       : OpRewritePattern<GenericOp>(context, benefit), options(options) {}
83 
matchAndRewritetensorflow::__anonf12ed84e0111::RowOrColumnReductionTilingPattern84   LogicalResult matchAndRewrite(GenericOp linalg_op,
85                                 PatternRewriter &rewriter) const override {
86     if (hasTransformationAttr(linalg_op)) return failure();
87     if (!isCanonicalizedReduction(linalg_op)) return failure();
88 
89     if (linalg_op.getNumOutputs() != 1) return failure();
90     if (linalg_op.getNumLoops() != 2) return failure();
91 
92     auto tiled_op = mlir::gml_st::tileLinalgOp(rewriter, linalg_op, options);
93     if (failed(tiled_op)) return failure();
94 
95     tiled_op->loops.front()->walk(
96         [&](LinalgOp tOp) { setTransformationAttr(rewriter, tOp); });
97 
98     rewriter.replaceOp(linalg_op, tiled_op->tensorResults);
99     return success();
100   }
101 
102  private:
103   LinalgTilingOptions options;
104 };
105 
106 // Rewrites a 1D reduction for vectorization. Matches `linalg.generic` that
107 // combines elements of tensor<?xELEM_TYPE> into tensor<ELEM_TYPE> and then
108 // creates a perfectly-tilable loop to reduce tensor<?xELEM_TYPE> ->
109 // tensor<VECTOR_SIZExELEM_TYPE> and an additional `linalg.generic` that reduces
110 // tensor<VECTOR_SIZExELEM_TYPE> to tensor<ELEM_TYPE>.
111 //
112 // Example:
113 //
114 // %sum = linalg.generic {
115 //   indexing_maps = [affine_map<(d0) -> (d0)>,
116 //                    affine_map<(d0) -> ()>],
117 //   iterator_types = ["reduction"]}
118 //   ins(%input : tensor<?xf32>)
119 //   outs(%fill : tensor<f32>) {
120 // ^bb0(%in: f32, %out: f32):
121 //   %add = arith.addf %in, %out : f32
122 //   linalg.yield %add : f32
123 // } -> tensor<f32>
124 //
125 // will be rewritten as
126 //
127 // %vector_result = gml_st.loop (%i)
128 //     = (%c0) to (%TILABLE_UB) step (%vector_size)
129 //     ins (%input_ = %input: tensor<?xf32>)
130 //     outs (%tmp_result_ = %tmp_result: tensor<VECTOR_SIZExf32>)
131 //     iterators["reduction"] {
132 //   %tile = tensor.extract_slice %arg2[%i] [%TILE_SIZE] [1]
133 //     : tensor<?xf32> to tensor<TILE_SIZExf32>
134 //   %tile_reshape = tensor.expand_shape %tile [[0, 1]]
135 //     : tensor<VECTOR_SIZExf32> into tensor<1xVECTOR_SIZExf32>
136 //   %combine = linalg.generic ins(%tile_reshape : tensor<1xVECTOR_SIZExf32>)
137 //     outs(%tmp_result_ : tensor<VECTOR_SIZExf32>) -> tensor<VECTOR_SIZExf32>
138 //   linalg.yield %combine : tensor<VECTOR_SIZExf32>
139 // }
140 // %horizontal_reduce = linalg.generic
141 //   ins(%vector_result : tensor<VECTOR_SIZExf32>)
142 //   outs(%fill : tensor<f32>) -> tensor<f32> // combiner only
143 // %result = gml_st.loop (%i)
144 //     = (%TILABLE_UB) to (%INPUT_SIZE) step (%vector_size)
145 //     ins (%input_ = %input: tensor<?xf32>)
146 //     outs (%tmp_result_ = %horizontal_reduce: tensor<f32>)
147 //     iterators["reduction"] {
148 //   linalg.generic // reduces the tail
149 // }
150 //
151 // This is necessary to push horizontal reduction to the later stage.
152 struct OneDimReductionTilingPattern : public OpRewritePattern<GenericOp> {
OneDimReductionTilingPatterntensorflow::__anonf12ed84e0111::OneDimReductionTilingPattern153   OneDimReductionTilingPattern(int64_t vector_size, int64_t tile_size,
154                                mlir::MLIRContext *context,
155                                mlir::PatternBenefit benefit = 1)
156       : OpRewritePattern<GenericOp>(context, benefit),
157         vector_size(vector_size),
158         tile_size(tile_size) {}
159 
matchAndRewritetensorflow::__anonf12ed84e0111::OneDimReductionTilingPattern160   LogicalResult matchAndRewrite(GenericOp linalg_op,
161                                 PatternRewriter &rewriter) const override {
162     if (hasTransformationAttr(linalg_op)) return failure();
163     if (!isCanonicalizedReduction(linalg_op)) return failure();
164 
165     // Check if all inputs have a 1D identity map.
166     if (linalg_op.getNumLoops() != 1) return failure();
167     auto indexing_maps = linalg_op.getIndexingMapsArray();
168     for (auto affine_map : makeArrayRef(indexing_maps).drop_back()) {
169       if (!affine_map.isIdentity()) return failure();
170     }
171 
172     Location loc = linalg_op.getLoc();
173     Value input = linalg_op.getInputOperand(0)->get();
174     // All inputs have the same size because of identity maps for indexing.
175     SmallVector<Value> inputs = linalg_op.inputs();
176     Value input_size = rewriter.create<mlir::tensor::DimOp>(loc, input, 0);
177 
178     auto fill_op = linalg_op.outputs().front().getDefiningOp<FillOp>();
179     auto init_op = fill_op.output().getDefiningOp<InitTensorOp>();
180 
181     auto neutral_value = fill_op.value();
182     auto element_type = init_op.getType().getElementType();
183 
184     Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
185     Value tile_size_value = rewriter.create<ConstantIndexOp>(loc, tile_size);
186     Value new_init = rewriter.create<InitTensorOp>(loc, ValueRange{},
187                                                    vector_size, element_type);
188     Value new_fill =
189         rewriter.create<FillOp>(loc, fill_op.value(), new_init).result();
190 
191     llvm::Optional<Value> tilable_bound_or =
192         getTilableBound(rewriter, loc, zero, input_size, tile_size_value);
193     Value tilable_bound =
194         tilable_bound_or.has_value() ? *tilable_bound_or : input_size;
195 
196     GenericOp tiled_reduction;
197     auto perfectly_tiled_loop = rewriter.create<LoopOp>(
198         loc, makeArrayRef(zero), makeArrayRef(tilable_bound),
199         makeArrayRef(tile_size_value), inputs, makeArrayRef(new_fill),
200         rewriter.getStrArrayAttr(mlir::getReductionIteratorTypeName()),
201         [&](OpBuilder &b, Location nested_loc, ValueRange ivs,
202             ValueRange inputs, ValueRange outputs) {
203           SmallVector<Value, 2> reshaped_tiled_inputs =
204               TileAndReshapeInputTensors(b, nested_loc, ivs, inputs,
205                                          neutral_value, input_size,
206                                          tile_size_value);
207           // Create `linalg.generic` to combine
208           // `tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE> input with
209           // the `tensor<VECTOR_SIZExELEM_TYPE>` output.
210           SmallVector<mlir::StringRef, 2> iter_types{
211               mlir::getReductionIteratorTypeName(),
212               mlir::getParallelIteratorTypeName()};
213           SmallVector<mlir::AffineMap, 2> indexing_maps(
214               inputs.size(), rewriter.getMultiDimIdentityMap(2));
215           indexing_maps.push_back(
216               mlir::AffineMap::get(2, 0, b.getAffineDimExpr(1)));
217           tiled_reduction = b.create<GenericOp>(
218               nested_loc, outputs[0].getType(), reshaped_tiled_inputs,
219               makeArrayRef({outputs[0]}), indexing_maps, iter_types,
220               /*bodyBuild=*/nullptr);
221           mlir::Region &region = tiled_reduction.region();
222           OpBuilder::InsertionGuard g(rewriter);
223           rewriter.cloneRegionBefore(linalg_op.region(), region, region.end());
224           b.create<mlir::gml_st::YieldOp>(nested_loc,
225                                           tiled_reduction.getResult(0));
226         });
227     // Create `linalg.generic` to reduce
228     // tensor<VECTOR_SIZExELEM_TYPE>->tensor<ELEM_TYPE>.
229     auto horizontal_reduction_or = ReduceVectorIntoOutput(
230         rewriter, linalg_op, perfectly_tiled_loop.getResult(0));
231     if (failed(horizontal_reduction_or)) return failure();
232     auto horizontal_reduction = horizontal_reduction_or.getValue();
233     Value result = horizontal_reduction->getResult(0);
234 
235     // If the loop was not perfectly tiled, then we have to combine
236     // `horizontal_reduction` with the elements in the `tail`.
237     if (tilable_bound_or.has_value()) {
238       auto final_reduction = rewriter.create<LoopOp>(
239           loc, tilable_bound, input_size, tile_size_value, inputs,
240           makeArrayRef(result),
241           rewriter.getStrArrayAttr(mlir::getReductionIteratorTypeName()),
242           [&](OpBuilder &b, Location nested_loc, ValueRange ivs,
243               ValueRange inputs, ValueRange outputs) {
244             BlockAndValueMapping bvm;
245             mlir::AffineExpr sym0, sym1;
246             bindSymbols(b.getContext(), sym0, sym1);
247             auto diff_map = mlir::AffineMap::get(0, 2, {sym1 - sym0});
248 
249             Value one = b.create<ConstantIndexOp>(nested_loc, 1);
250             auto size = b.createOrFold<mlir::AffineApplyOp>(
251                 nested_loc, diff_map, ValueRange{tilable_bound, input_size});
252             std::vector<Value> sliced_inputs;
253             sliced_inputs.reserve(inputs.size());
254             for (Value input : inputs) {
255               sliced_inputs.push_back(
256                   b.create<ExtractSliceOp>(nested_loc, input, ivs, size, one));
257             }
258             bvm.map(linalg_op.inputs(), sliced_inputs);
259             bvm.map(linalg_op.outputs(), outputs);
260             auto new_linalg_op = b.clone(*linalg_op.getOperation(), bvm);
261             setTransformationAttr(b, new_linalg_op);
262             b.create<mlir::gml_st::YieldOp>(nested_loc,
263                                             new_linalg_op->getResult(0));
264           });
265       result = final_reduction.getResult(0);
266     }
267     rewriter.replaceOp(linalg_op, result);
268 
269     perfectly_tiled_loop->walk(
270         [&](GenericOp op) { setTransformationAttr(rewriter, op); });
271     setTransformationAttr(rewriter, horizontal_reduction);
272     return success();
273   }
274 
275  private:
276   // Computes an upper bound that can be perfectly tiled. Return llvm::None, if
277   // the loop is already perfectly tiled.
getTilableBoundtensorflow::__anonf12ed84e0111::OneDimReductionTilingPattern278   mlir::Optional<Value> getTilableBound(OpBuilder &b, Location loc, Value lb,
279                                         Value ub, Value step) const {
280     auto lb_int = getConstantIntValue(lb);
281     auto ub_int = getConstantIntValue(ub);
282     auto step_int = getConstantIntValue(step);
283 
284     // No specialization necessary if step already divides upper bound evenly.
285     if (lb_int && ub_int && step_int && (*ub_int - *lb_int) % *step_int == 0)
286       return llvm::None;
287     // No specialization necessary if step size is 1.
288     if (mlir::isConstantIntValue(step, 1)) return llvm::None;
289     mlir::AffineExpr sym0, sym1, sym2;
290     bindSymbols(b.getContext(), sym0, sym1, sym2);
291 
292     // New upper bound: %ub - (%ub - %lb) mod %step
293     auto mod_map = mlir::AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
294     return {b.createOrFold<mlir::AffineApplyOp>(loc, mod_map,
295                                                 ValueRange{lb, ub, step})};
296   }
297 
298   // Tiles, pads and reshapes every input argument of type tensor<?xELEM_TYPE>
299   // into tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>.
TileAndReshapeInputTensorstensorflow::__anonf12ed84e0111::OneDimReductionTilingPattern300   SmallVector<Value, 2> TileAndReshapeInputTensors(
301       OpBuilder &b, Location nested_loc, ValueRange ivs, ValueRange inputs,
302       Value neutral_value, Value input_size, Value tile_size_value) const {
303     SmallVector<Value, 2> reshaped_tiled_inputs;
304 
305     SmallVector<mlir::ReassociationIndices> indices = {{0, 1}};
306     auto identity_1d_map = b.getMultiDimIdentityMap(1);
307     auto iv = ivs.front();
308 
309     mlir::OpFoldResult tile_size_fold = tile_size_value;
310     mlir::OpFoldResult input_size_fold = input_size;
311     auto tile_sizes = mlir::linalg::computeTileSizes(
312         b, nested_loc, tile_size_fold, input_size_fold);
313     for (auto input : inputs) {
314       // Extract slice of input.
315       Value slice = mlir::linalg::makeTiledShape(
316           b, nested_loc, input, tile_size_fold, identity_1d_map,
317           mlir::OpFoldResult(iv), input_size_fold, tile_sizes,
318           /*omitPartialTileCheck=*/true);
319       auto element_type = slice.getType().cast<ShapedType>().getElementType();
320 
321       // Reshape input tile to
322       // tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>.
323       Value expand_shape = b.create<ExpandShapeOp>(
324           nested_loc,
325           RankedTensorType::get({tile_size / vector_size, vector_size},
326                                 element_type),
327           slice, indices);
328       reshaped_tiled_inputs.push_back(expand_shape);
329     }
330     return reshaped_tiled_inputs;
331   }
332 
333   // Creates `linalg.generic` to reduce
334   // tensor<VECTOR_SIZExELEM_TYPE>->tensor<ELEM_TYPE>. To perform that we match
335   // the combiner in the original "untiled" linalg_op.
ReduceVectorIntoOutputtensorflow::__anonf12ed84e0111::OneDimReductionTilingPattern336   FailureOr<GenericOp> ReduceVectorIntoOutput(PatternRewriter &rewriter,
337                                               LinalgOp linalg_op,
338                                               Value partial_result) const {
339     SmallVector<mlir::StringRef, 3> reduction_iter_type(
340         1, mlir::getReductionIteratorTypeName());
341     auto map = mlir::AffineMap::get(1, 0, llvm::None, rewriter.getContext());
342 
343     auto combiner_or = DetectCombiner(linalg_op);
344     if (failed(combiner_or)) return failure();
345     Operation *combiner = combiner_or.getValue();
346 
347     auto accumulator = rewriter.create<GenericOp>(
348         linalg_op.getLoc(), linalg_op->getResultTypes(),
349         makeArrayRef(partial_result),
350         makeArrayRef(linalg_op.getOutputOperand(0)->get()),
351         makeArrayRef({rewriter.getMultiDimIdentityMap(1), map}),
352         reduction_iter_type,
353         [&](OpBuilder &b, Location nested_loc, ValueRange args) {
354           BlockAndValueMapping bvm;
355           bvm.map(combiner->getOperands(), args);
356           Value result_val = b.clone(*combiner, bvm)->getResult(0);
357           b.create<mlir::linalg::YieldOp>(nested_loc, result_val);
358         });
359     return accumulator;
360   }
361 
362  private:
363   int64_t vector_size;
364   int64_t tile_size;
365 };
366 
367 struct TileReductionPass : public TileReductionBase<TileReductionPass> {
368   TileReductionPass() = default;
TileReductionPasstensorflow::__anonf12ed84e0111::TileReductionPass369   TileReductionPass(int64_t vector_size, int64_t reduction_1d_tile,
370                     llvm::ArrayRef<int64_t> reduction_2d_tiles) {
371     reduction_vector_size = vector_size;
372     reduction_1d_tile_size = reduction_1d_tile;
373     reduction_2d_tile_sizes = reduction_2d_tiles;
374   }
runOnOperationtensorflow::__anonf12ed84e0111::TileReductionPass375   void runOnOperation() override {
376     auto func = getOperation();
377     auto context = func.getContext();
378 
379     assert(reduction_1d_tile_size % reduction_vector_size == 0 &&
380            "Tile size for 1D reduction should be a multiple of vector size");
381     auto patterns =
382         mlir::linalg::getLinalgTilingCanonicalizationPatterns(context);
383     patterns.add<OneDimReductionTilingPattern>(
384         reduction_vector_size, reduction_1d_tile_size, patterns.getContext());
385 
386     assert(reduction_2d_tile_sizes.size() == 2 &&
387            "Tiling sizes for 2D reductions should have two elements");
388     patterns.add<RowOrColumnReductionTilingPattern>(
389         LinalgTilingOptions{}.setTileSizes(reduction_2d_tile_sizes),
390         patterns.getContext());
391     (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
392 
393     // Ensure we drop the marker in the end.
394     func.walk([](LinalgOp op) { removeTransformationAttr(op); });
395   }
396 };
397 
398 }  // namespace
399 
400 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateTileReductionPass()401 CreateTileReductionPass() {
402   return std::make_unique<TileReductionPass>();
403 }
404 
405 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateTileReductionPass(int64_t reduction_vector_size,int64_t reduction_1d_tile_size,llvm::ArrayRef<int64_t> reduction_2d_tile_sizes)406 CreateTileReductionPass(int64_t reduction_vector_size,
407                         int64_t reduction_1d_tile_size,
408                         llvm::ArrayRef<int64_t> reduction_2d_tile_sizes) {
409   return std::make_unique<TileReductionPass>(
410       reduction_vector_size, reduction_1d_tile_size, reduction_2d_tile_sizes);
411 }
412 
413 }  // namespace tensorflow
414