/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>
#include <vector>

#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"

namespace tensorflow {
namespace {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"

using llvm::makeArrayRef;
using mlir::BlockAndValueMapping;
using mlir::dyn_cast;
using mlir::failure;
using mlir::FailureOr;
using mlir::Location;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpBuilder;
using mlir::Operation;
using mlir::OpRewritePattern;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
using mlir::ShapedType;
using mlir::SmallVector;
using mlir::success;
using mlir::Value;
using mlir::ValueRange;
using mlir::arith::ConstantIndexOp;
using mlir::gml_st::LoopOp;
using mlir::linalg::FillOp;
using mlir::linalg::GenericOp;
using mlir::linalg::InitTensorOp;
using mlir::linalg::LinalgOp;
using mlir::linalg::LinalgTilingOptions;
using mlir::tensor::ExpandShapeOp;
using mlir::tensor::ExtractSliceOp;
// Matches a `linalg.generic` that models a canonicalized 1D or 2D reduction:
// a single output and exactly one reduction loop.
bool isCanonicalizedReduction(Operation *op) {
  auto reduction = mlir::dyn_cast<GenericOp>(op);
  if (!reduction) return false;

  if (reduction.getNumOutputs() != 1) return false;
  if (reduction.getNumLoops() > 2) return false;
  return reduction.getNumReductionLoops() == 1;
}

// Tiles a GenericOp that models a 2D row or column reduction.
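//
// A rough sketch of the rewrite (tile sizes come from the pass options; the
// shapes below are only illustrative): a row reduction
//
//   %sum = linalg.generic {
//       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                        affine_map<(d0, d1) -> (d0)>],
//       iterator_types = ["parallel", "reduction"]}
//       ins(%input : tensor<?x?xf32>)
//       outs(%fill : tensor<?xf32>) -> tensor<?xf32>
//
// becomes a `gml_st.loop` over [TILE_M, TILE_N] tiles that applies the same
// `linalg.generic` to `tensor.extract_slice` tiles of %input and %fill and
// yields the partial results.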
struct RowOrColumnReductionTilingPattern : public OpRewritePattern<GenericOp> {
  RowOrColumnReductionTilingPattern(const LinalgTilingOptions &options,
                                    MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(GenericOp linalg_op,
                                PatternRewriter &rewriter) const override {
    if (hasTransformationAttr(linalg_op)) return failure();
    if (!isCanonicalizedReduction(linalg_op)) return failure();

    if (linalg_op.getNumOutputs() != 1) return failure();
    if (linalg_op.getNumLoops() != 2) return failure();

    auto tiled_op = mlir::gml_st::tileLinalgOp(rewriter, linalg_op, options);
    if (failed(tiled_op)) return failure();

    tiled_op->loops.front()->walk(
        [&](LinalgOp tOp) { setTransformationAttr(rewriter, tOp); });

    rewriter.replaceOp(linalg_op, tiled_op->tensorResults);
    return success();
  }

 private:
  LinalgTilingOptions options;
};

// Rewrites a 1D reduction for vectorization. Matches a `linalg.generic` that
// combines elements of tensor<?xELEM_TYPE> into tensor<ELEM_TYPE> and then
// creates a perfectly-tilable loop to reduce tensor<?xELEM_TYPE> ->
// tensor<VECTOR_SIZExELEM_TYPE> and an additional `linalg.generic` that
// reduces tensor<VECTOR_SIZExELEM_TYPE> to tensor<ELEM_TYPE>.
//
// Example:
//
// %sum = linalg.generic {
//     indexing_maps = [affine_map<(d0) -> (d0)>,
//                      affine_map<(d0) -> ()>],
//     iterator_types = ["reduction"]}
//     ins(%input : tensor<?xf32>)
//     outs(%fill : tensor<f32>) {
//   ^bb0(%in: f32, %out: f32):
//     %add = arith.addf %in, %out : f32
//     linalg.yield %add : f32
// } -> tensor<f32>
//
// will be rewritten as
//
// %vector_result = gml_st.loop (%i)
//     = (%c0) to (%TILABLE_UB) step (%TILE_SIZE)
//     ins (%input_ = %input: tensor<?xf32>)
//     outs (%tmp_result_ = %tmp_result: tensor<VECTOR_SIZExf32>)
//     iterators["reduction"] {
//   %tile = tensor.extract_slice %input_[%i] [%TILE_SIZE] [1]
//     : tensor<?xf32> to tensor<TILE_SIZExf32>
//   %tile_reshape = tensor.expand_shape %tile [[0, 1]]
//     : tensor<TILE_SIZExf32> into tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExf32>
//   %combine = linalg.generic
//     ins(%tile_reshape : tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExf32>)
//     outs(%tmp_result_ : tensor<VECTOR_SIZExf32>) -> tensor<VECTOR_SIZExf32>
//   linalg.yield %combine : tensor<VECTOR_SIZExf32>
// }
// %horizontal_reduce = linalg.generic
//   ins(%vector_result : tensor<VECTOR_SIZExf32>)
//   outs(%fill : tensor<f32>) -> tensor<f32> // combiner only
// %result = gml_st.loop (%i)
//     = (%TILABLE_UB) to (%INPUT_SIZE) step (%TILE_SIZE)
//     ins (%input_ = %input: tensor<?xf32>)
//     outs (%tmp_result_ = %horizontal_reduce: tensor<f32>)
//     iterators["reduction"] {
//   linalg.generic // reduces the tail
// }
//
// Splitting the reduction this way pushes the horizontal reduction of the
// vector accumulator to a later stage.
struct OneDimReductionTilingPattern : public OpRewritePattern<GenericOp> {
  OneDimReductionTilingPattern(int64_t vector_size, int64_t tile_size,
                               mlir::MLIRContext *context,
                               mlir::PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        vector_size(vector_size),
        tile_size(tile_size) {}

  LogicalResult matchAndRewrite(GenericOp linalg_op,
                                PatternRewriter &rewriter) const override {
    if (hasTransformationAttr(linalg_op)) return failure();
    if (!isCanonicalizedReduction(linalg_op)) return failure();

    // Check if all inputs have a 1D identity map.
    if (linalg_op.getNumLoops() != 1) return failure();
    auto indexing_maps = linalg_op.getIndexingMapsArray();
    for (auto affine_map : makeArrayRef(indexing_maps).drop_back()) {
      if (!affine_map.isIdentity()) return failure();
    }

    Location loc = linalg_op.getLoc();
    Value input = linalg_op.getInputOperand(0)->get();
    // All inputs have the same size because of identity maps for indexing.
    SmallVector<Value> inputs = linalg_op.inputs();
    Value input_size = rewriter.create<mlir::tensor::DimOp>(loc, input, 0);

    auto fill_op = linalg_op.outputs().front().getDefiningOp<FillOp>();
    auto init_op = fill_op.output().getDefiningOp<InitTensorOp>();

    auto neutral_value = fill_op.value();
    auto element_type = init_op.getType().getElementType();

    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value tile_size_value = rewriter.create<ConstantIndexOp>(loc, tile_size);
    Value new_init = rewriter.create<InitTensorOp>(loc, ValueRange{},
                                                   vector_size, element_type);
    Value new_fill =
        rewriter.create<FillOp>(loc, fill_op.value(), new_init).result();

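    // Split the iteration space: [0, tilable_bound) is an exact multiple of
    // the tile size and is handled by the perfectly tiled loop below; the
    // remaining [tilable_bound, input_size) tail, if any, is reduced by a
    // second loop further down.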
    llvm::Optional<Value> tilable_bound_or =
        getTilableBound(rewriter, loc, zero, input_size, tile_size_value);
    Value tilable_bound =
        tilable_bound_or.has_value() ? *tilable_bound_or : input_size;

    GenericOp tiled_reduction;
    auto perfectly_tiled_loop = rewriter.create<LoopOp>(
        loc, makeArrayRef(zero), makeArrayRef(tilable_bound),
        makeArrayRef(tile_size_value), inputs, makeArrayRef(new_fill),
        rewriter.getStrArrayAttr(mlir::getReductionIteratorTypeName()),
        [&](OpBuilder &b, Location nested_loc, ValueRange ivs,
            ValueRange inputs, ValueRange outputs) {
          SmallVector<Value, 2> reshaped_tiled_inputs =
              TileAndReshapeInputTensors(b, nested_loc, ivs, inputs,
                                         neutral_value, input_size,
                                         tile_size_value);
          // Create `linalg.generic` to combine the
          // `tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>` input with
          // the `tensor<VECTOR_SIZExELEM_TYPE>` output.
          SmallVector<mlir::StringRef, 2> iter_types{
              mlir::getReductionIteratorTypeName(),
              mlir::getParallelIteratorTypeName()};
          SmallVector<mlir::AffineMap, 2> indexing_maps(
              inputs.size(), rewriter.getMultiDimIdentityMap(2));
          indexing_maps.push_back(
              mlir::AffineMap::get(2, 0, b.getAffineDimExpr(1)));
          tiled_reduction = b.create<GenericOp>(
              nested_loc, outputs[0].getType(), reshaped_tiled_inputs,
              makeArrayRef({outputs[0]}), indexing_maps, iter_types,
              /*bodyBuild=*/nullptr);
          mlir::Region &region = tiled_reduction.region();
          OpBuilder::InsertionGuard g(rewriter);
          rewriter.cloneRegionBefore(linalg_op.region(), region, region.end());
          b.create<mlir::gml_st::YieldOp>(nested_loc,
                                          tiled_reduction.getResult(0));
        });
    // Create `linalg.generic` to reduce
    // tensor<VECTOR_SIZExELEM_TYPE> -> tensor<ELEM_TYPE>.
    auto horizontal_reduction_or = ReduceVectorIntoOutput(
        rewriter, linalg_op, perfectly_tiled_loop.getResult(0));
    if (failed(horizontal_reduction_or)) return failure();
    auto horizontal_reduction = horizontal_reduction_or.getValue();
    Value result = horizontal_reduction->getResult(0);

    // If the loop was not perfectly tiled, then we have to combine
    // `horizontal_reduction` with the elements in the `tail`.
    if (tilable_bound_or.has_value()) {
      auto final_reduction = rewriter.create<LoopOp>(
          loc, tilable_bound, input_size, tile_size_value, inputs,
          makeArrayRef(result),
          rewriter.getStrArrayAttr(mlir::getReductionIteratorTypeName()),
          [&](OpBuilder &b, Location nested_loc, ValueRange ivs,
              ValueRange inputs, ValueRange outputs) {
            BlockAndValueMapping bvm;
            mlir::AffineExpr sym0, sym1;
            bindSymbols(b.getContext(), sym0, sym1);
            auto diff_map = mlir::AffineMap::get(0, 2, {sym1 - sym0});

            Value one = b.create<ConstantIndexOp>(nested_loc, 1);
            auto size = b.createOrFold<mlir::AffineApplyOp>(
                nested_loc, diff_map, ValueRange{tilable_bound, input_size});
            std::vector<Value> sliced_inputs;
            sliced_inputs.reserve(inputs.size());
            for (Value input : inputs) {
              sliced_inputs.push_back(
                  b.create<ExtractSliceOp>(nested_loc, input, ivs, size, one));
            }
            bvm.map(linalg_op.inputs(), sliced_inputs);
            bvm.map(linalg_op.outputs(), outputs);
            auto new_linalg_op = b.clone(*linalg_op.getOperation(), bvm);
            setTransformationAttr(b, new_linalg_op);
            b.create<mlir::gml_st::YieldOp>(nested_loc,
                                            new_linalg_op->getResult(0));
          });
      result = final_reduction.getResult(0);
    }
    rewriter.replaceOp(linalg_op, result);

    perfectly_tiled_loop->walk(
        [&](GenericOp op) { setTransformationAttr(rewriter, op); });
    setTransformationAttr(rewriter, horizontal_reduction);
    return success();
  }

 private:
  // Computes an upper bound that can be perfectly tiled. Returns llvm::None if
  // the loop is already perfectly tiled.
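  // For example, with lb = 0, ub = %n and step = 8, the returned bound is
  // %n - (%n mod 8), so the remaining tail has fewer than 8 elements.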
  mlir::Optional<Value> getTilableBound(OpBuilder &b, Location loc, Value lb,
                                        Value ub, Value step) const {
    auto lb_int = getConstantIntValue(lb);
    auto ub_int = getConstantIntValue(ub);
    auto step_int = getConstantIntValue(step);

    // No specialization necessary if step already divides upper bound evenly.
    if (lb_int && ub_int && step_int && (*ub_int - *lb_int) % *step_int == 0)
      return llvm::None;
    // No specialization necessary if step size is 1.
    if (mlir::isConstantIntValue(step, 1)) return llvm::None;
    mlir::AffineExpr sym0, sym1, sym2;
    bindSymbols(b.getContext(), sym0, sym1, sym2);

    // New upper bound: %ub - (%ub - %lb) mod %step
    auto mod_map = mlir::AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
    return {b.createOrFold<mlir::AffineApplyOp>(loc, mod_map,
                                                ValueRange{lb, ub, step})};
  }

  // Tiles, pads and reshapes every input argument of type tensor<?xELEM_TYPE>
  // into tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>.
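  // For example, with tile_size = 32 and vector_size = 8, each 32-element
  // tile of a tensor<?xf32> input is expanded into a tensor<4x8xf32>.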
  SmallVector<Value, 2> TileAndReshapeInputTensors(
      OpBuilder &b, Location nested_loc, ValueRange ivs, ValueRange inputs,
      Value neutral_value, Value input_size, Value tile_size_value) const {
    SmallVector<Value, 2> reshaped_tiled_inputs;

    SmallVector<mlir::ReassociationIndices> indices = {{0, 1}};
    auto identity_1d_map = b.getMultiDimIdentityMap(1);
    auto iv = ivs.front();

    mlir::OpFoldResult tile_size_fold = tile_size_value;
    mlir::OpFoldResult input_size_fold = input_size;
    auto tile_sizes = mlir::linalg::computeTileSizes(
        b, nested_loc, tile_size_fold, input_size_fold);
    for (auto input : inputs) {
      // Extract slice of input.
      Value slice = mlir::linalg::makeTiledShape(
          b, nested_loc, input, tile_size_fold, identity_1d_map,
          mlir::OpFoldResult(iv), input_size_fold, tile_sizes,
          /*omitPartialTileCheck=*/true);
      auto element_type = slice.getType().cast<ShapedType>().getElementType();

      // Reshape input tile to
      // tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>.
      Value expand_shape = b.create<ExpandShapeOp>(
          nested_loc,
          RankedTensorType::get({tile_size / vector_size, vector_size},
                                element_type),
          slice, indices);
      reshaped_tiled_inputs.push_back(expand_shape);
    }
    return reshaped_tiled_inputs;
  }

  // Creates a `linalg.generic` that reduces
  // tensor<VECTOR_SIZExELEM_TYPE> -> tensor<ELEM_TYPE>. To do so, it reuses
  // the combiner of the original "untiled" linalg_op.
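  // For a floating-point sum reduction, for instance, the detected combiner
  // is the `arith.addf` in the body of `linalg_op`, and it is cloned into the
  // body of the new op.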
  FailureOr<GenericOp> ReduceVectorIntoOutput(PatternRewriter &rewriter,
                                              LinalgOp linalg_op,
                                              Value partial_result) const {
    SmallVector<mlir::StringRef, 3> reduction_iter_type(
        1, mlir::getReductionIteratorTypeName());
    auto map = mlir::AffineMap::get(1, 0, llvm::None, rewriter.getContext());

    auto combiner_or = DetectCombiner(linalg_op);
    if (failed(combiner_or)) return failure();
    Operation *combiner = combiner_or.getValue();

    auto accumulator = rewriter.create<GenericOp>(
        linalg_op.getLoc(), linalg_op->getResultTypes(),
        makeArrayRef(partial_result),
        makeArrayRef(linalg_op.getOutputOperand(0)->get()),
        makeArrayRef({rewriter.getMultiDimIdentityMap(1), map}),
        reduction_iter_type,
        [&](OpBuilder &b, Location nested_loc, ValueRange args) {
          BlockAndValueMapping bvm;
          bvm.map(combiner->getOperands(), args);
          Value result_val = b.clone(*combiner, bvm)->getResult(0);
          b.create<mlir::linalg::YieldOp>(nested_loc, result_val);
        });
    return accumulator;
  }

 private:
  int64_t vector_size;
  int64_t tile_size;
};

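// Pass that tiles `linalg.generic` reductions: 1D reductions are rewritten
// with OneDimReductionTilingPattern (perfectly tiled loop + tail), and 2D row
// and column reductions with RowOrColumnReductionTilingPattern using the
// configured 2D tile sizes.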
struct TileReductionPass : public TileReductionBase<TileReductionPass> {
  TileReductionPass() = default;
  TileReductionPass(int64_t vector_size, int64_t reduction_1d_tile,
                    llvm::ArrayRef<int64_t> reduction_2d_tiles) {
    reduction_vector_size = vector_size;
    reduction_1d_tile_size = reduction_1d_tile;
    reduction_2d_tile_sizes = reduction_2d_tiles;
  }
  void runOnOperation() override {
    auto func = getOperation();
    auto context = func.getContext();

    assert(reduction_1d_tile_size % reduction_vector_size == 0 &&
           "Tile size for 1D reduction should be a multiple of vector size");
    auto patterns =
        mlir::linalg::getLinalgTilingCanonicalizationPatterns(context);
    patterns.add<OneDimReductionTilingPattern>(
        reduction_vector_size, reduction_1d_tile_size, patterns.getContext());

    assert(reduction_2d_tile_sizes.size() == 2 &&
           "Tiling sizes for 2D reductions should have two elements");
    patterns.add<RowOrColumnReductionTilingPattern>(
        LinalgTilingOptions{}.setTileSizes(reduction_2d_tile_sizes),
        patterns.getContext());
    (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));

    // Ensure we drop the marker in the end.
    func.walk([](LinalgOp op) { removeTransformationAttr(op); });
  }
};

}  // namespace

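// A minimal usage sketch for the factory functions below (the surrounding
// pipeline setup is assumed):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(CreateTileReductionPass(
//       /*reduction_vector_size=*/8, /*reduction_1d_tile_size=*/32,
//       /*reduction_2d_tile_sizes=*/{4, 4}));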
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateTileReductionPass() {
  return std::make_unique<TileReductionPass>();
}

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateTileReductionPass(int64_t reduction_vector_size,
                        int64_t reduction_1d_tile_size,
                        llvm::ArrayRef<int64_t> reduction_2d_tile_sizes) {
  return std::make_unique<TileReductionPass>(
      reduction_vector_size, reduction_1d_tile_size, reduction_2d_tile_sizes);
}

}  // namespace tensorflow