1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
20 #include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
21 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/DialectConversion.h"
29
30 namespace mlir {
31 namespace lmhlo {
32 namespace {
33
34 // Clones and adapts the code in `lhlo_block` that works on buffers and has a
35 // single output buffer to make it compatible with `operands` that have element
36 // types of the respective buffers. Returns the computed value.
37 //
38 // Example. For `operands` with (f32, i32) types and a block with LHLO ops and
39 // with signature:
40 // ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
41 // <LHLO_ops>
42 //
43 // inserts necessary alloc and store ops to compute and return result that has
44 // `i1` type.
applySingleResultLhloCode(Location loc,ValueRange operands,Block * lhloBlock,OpBuilder * b)45 Value applySingleResultLhloCode(Location loc, ValueRange operands,
46 Block* lhloBlock, OpBuilder* b) {
47 SmallVector<Value, 2> argBufs;
48 for (auto argType : lhloBlock->getArgumentTypes()) {
49 argBufs.push_back(
50 b->create<memref::AllocOp>(loc, argType.cast<MemRefType>()));
51 }
52 for (const auto& operand : llvm::enumerate(operands)) {
53 b->create<memref::StoreOp>(loc, operand.value(), argBufs[operand.index()]);
54 }
55 // Clone the ops from `lhlo_block`.
56 BlockAndValueMapping mapping;
57 mapping.map(lhloBlock->getArguments(), argBufs);
58 for (auto& nested : lhloBlock->without_terminator()) {
59 auto* clone = b->clone(nested, mapping);
60 mapping.map(nested.getResults(), clone->getResults());
61 }
62 return b->create<memref::LoadOp>(loc, argBufs.back());
63 }
64
65 // Converts a block with LHLO ops and with signature:
66 // ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
67 // into a reduction operator of scf.reduce by doing buffer allocation for
68 // scalar arguments and the result of `scf.reduce` to make it compatible with
69 // LHLO ops.
convertToReductionOperator(Location loc,scf::ReduceOp reduceOp,Block * lhloBlock,OpBuilder * b)70 void convertToReductionOperator(Location loc, scf::ReduceOp reduceOp,
71 Block* lhloBlock, OpBuilder* b) {
72 Block& loopReduceOpBody = reduceOp.getReductionOperator().front();
73 OpBuilder::InsertionGuard guard(*b);
74 b->setInsertionPointToStart(&loopReduceOpBody);
75 b->create<scf::ReduceReturnOp>(
76 loc, applySingleResultLhloCode(loc, loopReduceOpBody.getArguments(),
77 lhloBlock, b));
78 }
79
80 // Returns result of arith::ConstantOp if `dim` is static, otherwise uses DimOp
81 // to extract dimension at runtime.
getStaticOrDynamicDim(mlir::Location loc,Value shapedValue,size_t dimIndex,int64_t dim,OpBuilder * b)82 Value getStaticOrDynamicDim(mlir::Location loc, Value shapedValue,
83 size_t dimIndex, int64_t dim, OpBuilder* b) {
84 return dim == ShapedType::kDynamicSize
85 ? b->create<memref::DimOp>(loc, shapedValue, dimIndex).getResult()
86 : b->create<arith::ConstantIndexOp>(loc, dim);
87 }
88
// Result of mapping window/output induction variables into the coordinate
// space of an input buffer (see `mapWindowIvsToInput`).
struct MappedIvs {
  // False if the mapped indices are in the padding area, true otherwise.
  Value inBounds;
  // Mapped indices.
  SmallVector<Value, 2> ivs;
};
95
// Maps the output-loop ivs (`ivs`) and window-loop ivs (`windowIvs`) of a
// windowed op to indices into `operand`:
//   index[i] = ivs[i] * window_strides[i] + windowIvs[i] - padding[i].low
// Also computes an i1 that is true iff every mapped index is inside the
// operand, i.e. the window element does not fall into the padding area.
template <typename OpTy>
MappedIvs mapWindowIvsToInput(OpTy op, Value operand, ValueRange ivs,
                              ValueRange windowIvs, OpBuilder* b) {
  MappedIvs mappedIvs;

  // NOTE(review): emitOpError only reports a diagnostic; control still falls
  // through to the .value() call below, which asserts on an absent attribute.
  if (!op.getWindowStrides().has_value()) {
    op.emitOpError("No window strides specified.");
  }
  auto windowStrides = op.getWindowStrides().value();

  if (!op.getPadding().has_value()) {
    op.emitOpError("No padding specified.");
  }
  auto padding = op.getPadding().value();

  auto loc = op.getLoc();
  auto operandShape = operand.getType().template cast<MemRefType>().getShape();

  // `in_bounds` is false when the mapped indices are in the padding area.
  // Start with `true` and AND in a per-dimension bounds check below.
  mappedIvs.inBounds = b->create<mlir::arith::ConstantOp>(
      loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
  for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
    // `padding` is indexed as {dim, 0} for the low-padding entry of dim i.
    auto stride = windowStrides.template getValues<llvm::APInt>()[i];
    auto padLow = padding.template getValues<llvm::APInt>()[{i, 0}];

    Value strideVal =
        b->create<arith::ConstantIndexOp>(loc, stride.getSExtValue());
    Value padLowVal =
        b->create<arith::ConstantIndexOp>(loc, padLow.getSExtValue());

    Value center = b->create<arith::MulIOp>(loc, ivs[i], strideVal);
    Value offset = b->create<arith::SubIOp>(loc, windowIvs[i], padLowVal);
    Value index = b->create<arith::AddIOp>(loc, center, offset);
    Value upperBound =
        getStaticOrDynamicDim(loc, operand, i, operandShape[i], b);
    // We must check whether 0 <= index_i < shape_i, as otherwise we are in
    // the pad and then we have to use the neutral element for reduction.
    // Equivalently, it can be computed as the unsigned comparison index_i <
    // shape_i, since a negative value wraps to a large positive value.
    mappedIvs.inBounds = b->create<mlir::arith::AndIOp>(
        loc, mappedIvs.inBounds,
        b->create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, index,
                                 upperBound));
    mappedIvs.ivs.push_back(index);
  }
  return mappedIvs;
}
143
144 // Returns scf::Parallel over a shaped value with static or dynamic shape.
makeLoopOverShape(Location loc,Value shapedValue,OpBuilder * b)145 scf::ParallelOp makeLoopOverShape(Location loc, Value shapedValue,
146 OpBuilder* b) {
147 Value zero = b->create<arith::ConstantIndexOp>(loc, 0);
148 Value one = b->create<arith::ConstantIndexOp>(loc, 1);
149
150 ArrayRef<int64_t> shape = shapedValue.getType().cast<ShapedType>().getShape();
151 SmallVector<Value, 2> lower, upper, step;
152 for (const auto& dim : llvm::enumerate(shape)) {
153 upper.push_back(
154 getStaticOrDynamicDim(loc, shapedValue, dim.index(), dim.value(), b));
155 lower.push_back(zero);
156 step.push_back(one);
157 }
158 return b->create<scf::ParallelOp>(loc, lower, upper, step);
159 }
160
161 // Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
164 // contains the reduction operator.
165 //
166 // Example:
167 //
168 // "lmhlo.reduce"(%buffer, %init_buf, %result) ({
169 // ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
170 // <LHLO ops>
171 // } ) {dimensions = dense<[1]> : tensor<1xi64>}
172 // : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
173 //
174 // is roughly converted into:
175 //
176 // %init = load %init_buf[] : memref<f32>
177 // scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
178 // %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
179 // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
180 // scf.reduce(%elem_to_reduce) {
181 // ^bb0(%elem: f32, %acc: f32):
182 // elem_buf = alloc() : memref<f32>
183 // store %elem, elem_buf[] : memref<f32>
184 // acc_buf = alloc() : memref<f32>
185 // store %acc, acc_buf[] : memref<f32>
186 // <LHLO_ops>
187 // %acc_result = load acc_buf[] : memref<f32>
188 // scf.reduce.return %acc_result : f32
189 // } : f32
190 // scf.yield
191 // } : f32
192 // scf.yield
193 // }
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduceOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
    if (reduceOp.getOut().size() != 1) return failure();

    // Build the parallel/reduction loop nest, then inline the op's reduction
    // body into the scf.reduce region, and finally erase the original op.
    scf::ReduceOp scfReduceOp =
        createReduceOpInNestedParallelLoops(reduceOp, &rewriter);
    convertToReductionOperator(reduceOp.getLoc(), scfReduceOp,
                               &reduceOp.getBody().front(), &rewriter);
    rewriter.replaceOp(reduceOp, llvm::None);
    return success();
  }

 private:
  // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
  // refers to the parallel dimensions of `reduce_op` if any and the inner
  // ParallelOp refers to the reduction dimensions. The scf.reduce op is
  // returned.
  //
  // If the reduction argument is a memref<100x10x5xf32> and the
  // reduction is performed along dimension 1 then this method will generate
  //
  //  %init = load %init_buf[] : memref<f32>
  //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
  //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
  //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
  //      scf.reduce(%elem_to_reduce)  {
  //        <THE BLOCK PTR TO BE RETURNED>
  //      } : f32
  //      scf.yield
  //    } : f32
  //    scf.yield
  //  }
  scf::ReduceOp createReduceOpInNestedParallelLoops(
      lmhlo::ReduceOp reduceOp, ConversionPatternRewriter* rewriter) const {
    auto loc = reduceOp.getLoc();
    // Set of operand dimensions that are reduced (from the `dimensions`
    // attribute).
    DenseSet<int> reducingDims;
    for (const auto& rdim : reduceOp.getDimensions().getValues<APInt>()) {
      reducingDims.insert(rdim.getSExtValue());
    }

    Value operand = reduceOp.getInputs().front();
    Value out = reduceOp.getOut().front();
    // Partition each operand dimension's bounds into either the inner
    // (reduction) loop or the outer (parallel) loop.
    SmallVector<Value, 2> parallelLower, parallelUpper, parallelStep;
    SmallVector<Value, 2> reduceLower, reduceUpper, reduceStep;
    auto operandShape = operand.getType().cast<MemRefType>().getShape();
    for (const auto& dim : llvm::enumerate(operandShape)) {
      const bool isReducingDim = reducingDims.count(dim.index());

      Value ub = getStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
                                       rewriter);
      Value lb = rewriter->create<arith::ConstantIndexOp>(loc, 0);
      Value step = rewriter->create<arith::ConstantIndexOp>(loc, 1);
      (isReducingDim ? reduceLower : parallelLower).push_back(lb);
      (isReducingDim ? reduceUpper : parallelUpper).push_back(ub);
      (isReducingDim ? reduceStep : parallelStep).push_back(step);
    }
    // Load initial value from memref<element_type>.
    SmallVector<Value, 1> initValue = {rewriter->create<memref::LoadOp>(
        loc, *reduceOp.getInitValues().begin())};
    // Outer ParallelOp is not needed if it is a reduction across all dims.
    scf::ParallelOp outer;
    if (!parallelLower.empty()) {
      outer = rewriter->create<scf::ParallelOp>(loc, parallelLower,
                                                parallelUpper, parallelStep);
      rewriter->setInsertionPointToStart(outer.getBody());
    }
    // The inner loop carries the accumulator (init value); its single result
    // is the fully reduced value.
    scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
        loc, reduceLower, reduceUpper, reduceStep, ValueRange(initValue));
    Value reductionResult = *inner.getResults().begin();

    // Store indices: the outer loop's ivs, or a single constant 0 when every
    // dimension is reduced (scalar-like output buffer).
    SmallVector<Value, 1> outIndices;
    if (outer != nullptr) {
      outIndices.reserve(outer.getNumLoops());
      for (Value iv : outer.getInductionVars()) {
        outIndices.push_back(iv);
      }
    } else {
      outIndices.push_back(rewriter->create<arith::ConstantIndexOp>(loc, 0));
    }

    rewriter->create<memref::StoreOp>(loc, reductionResult, out, outIndices);

    // Load the element to reduce.
    SmallVector<Value, 2> indices;
    indices.reserve(operandShape.size());

    // Interleave inner (reduction) and outer (parallel) induction variables in
    // operand-dimension order.
    if (outer) {
      auto innerIvsIt = inner.getInductionVars().begin();
      auto outerIvsIt = outer.getInductionVars().begin();
      for (unsigned i = 0, e = operandShape.size(); i < e; ++i) {
        indices.push_back(reducingDims.count(i) ? *innerIvsIt++
                                                : *outerIvsIt++);
      }
    } else {
      indices = inner.getInductionVars();
    }

    // Emit the element load and the scf.reduce inside the inner loop body;
    // its reduction region is filled in later by convertToReductionOperator.
    rewriter->setInsertionPointToStart(inner.getBody());
    Value elem = rewriter->create<mlir::memref::LoadOp>(
        loc, reduceOp.getInputs().front(), indices);
    return rewriter->create<scf::ReduceOp>(loc, elem);
  }
};
303
304 // Pseudocode:
305 // for each index O in output
306 // accumulator = neutral_value
307 // in_bounds = true
308 // for each index W in window
309 // for each dimension i from 0 to rank - 1
310 // index = O[i] * stride[i] + W[i] - pad_low[i]
311 // in_bounds = inbounds && (index `ult` shape[i])
312 // I[i] = index
313 // if (in_bounds)
314 // value = input[I]
315 // else
316 // value = neutral_value
317 // accumulator = reduction_operator(accumulator, value)
318 // output[O] = accumulator
319 //
320 // Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
321 // scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the
// output buffer. The inner `ParallelOp` refers to the reduction loops that
// traverse reduction windows and `ReduceOp` contains the reduction operator.
325 //
326 // Example:
327 //
328 // func @reduce_window(%arg: memref<112x112xf32>,
329 // %init: memref<f32>,
330 // %result: memref<56x56xf32>) {
331 // "lmhlo.reduce_window"(%arg, %init, %result) ({
332 // ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
333 // "lmhlo.maximum"(%lhs, %rhs, %res)
334 // : (memref<f32>, memref<f32>, memref<f32>) -> ()
335 // "lmhlo.terminator"() : () -> ()
336 // }) {
337 // padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
338 // window_dimensions = dense<[3, 3]> : tensor<2xi64>,
339 // window_strides = dense<[2, 2]> : tensor<2xi64>
340 // } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
341 // return
342 // }
343 //
344 // is roughly converted into:
345 //
346 // %neutral_elem = load %init_buf[] : memref<f32>
347 // scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
348 // %result = scf.parallel (%iw, %jw) = (%c0, %c0)
349 // to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 {
350 // %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
351 // %elem = load %operand[%computed_i, %computed_j]
352 // %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
353 // scf.reduce(%elem_to_reduce) : f32 {
354 // ^bb0(%arg7: f32, %arg8: f32):
355 // <LHLO ops>
356 // }
357 // scf.yield
358 // }
359 // store %result, %output_buffer[%i, %j] : memref<56x56xf32>
360 // scf.yield
361 // }
362 // return
363 // }
class ReduceWindowOpConverter
    : public OpConversionPattern<lmhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::ReduceWindowOp reduceWindowOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
    if (reduceWindowOp.getOut().size() != 1) return failure();

    // 1) Loop over the output, with a nested loop over the window that
    //    carries the init value as accumulator.
    scf::ParallelOp outputLoop, windowLoop;
    std::tie(outputLoop, windowLoop) =
        createParallelLoopsToTraverseOutputAndWindow(reduceWindowOp, &rewriter);

    // 2) In-bounds check + element load + scf.reduce inside the window loop.
    scf::ReduceOp reduceOp = createReduceOpInNestedParallelLoops(
        reduceWindowOp, outputLoop, windowLoop, &rewriter);

    // 3) Inline the lmhlo reduction body into the scf.reduce region.
    convertToReductionOperator(reduceWindowOp.getLoc(), reduceOp,
                               &reduceWindowOp.getBody().front(), &rewriter);
    rewriter.replaceOp(reduceWindowOp, llvm::None);
    return success();
  }

 private:
  // Creates the outer loop over the output shape and, inside it, the window
  // loop seeded with the init value; stores the window loop's result into the
  // output buffer. Returns {output_loop, window_loop}.
  std::pair<scf::ParallelOp, scf::ParallelOp>
  createParallelLoopsToTraverseOutputAndWindow(
      lmhlo::ReduceWindowOp reduceWindowOp,
      ConversionPatternRewriter* rewriter) const {
    auto loc = reduceWindowOp.getLoc();
    Value initValue = rewriter->create<memref::LoadOp>(
        loc, reduceWindowOp.getInitValues()[0]);

    Value zero = rewriter->create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter->create<arith::ConstantIndexOp>(loc, 1);

    // Create an outer parallel loop that spans the output of ReduceWindowOp.
    Value output = reduceWindowOp.getOut()[0];
    auto outputLoop = makeLoopOverShape(loc, output, rewriter);

    // Create a nested loop that traverses the window.
    SmallVector<Value, 2> windowLower, windowUpper, windowStep;
    rewriter->setInsertionPointToStart(outputLoop.getBody());
    for (const auto& windowDim : reduceWindowOp.getWindowDimensions()) {
      windowStep.push_back(one);
      windowLower.push_back(zero);
      windowUpper.push_back(rewriter->create<arith::ConstantIndexOp>(
          loc, windowDim.getSExtValue()));
    }
    auto windowLoop = rewriter->create<scf::ParallelOp>(
        loc, windowLower, windowUpper, windowStep, ValueRange(initValue));

    // The window loop yields the reduced value; write it at the current
    // output coordinates.
    Value reductionResult = *windowLoop.getResults().begin();
    auto outputIvs = outputLoop.getInductionVars();
    rewriter->create<memref::StoreOp>(loc, reductionResult, output, outputIvs);
    return std::make_pair(outputLoop, windowLoop);
  }

  // Inside `windowLoop`: maps the loop ivs to operand indices, loads the
  // element when in bounds (else reuses the init/neutral value via scf.if),
  // and creates the scf.reduce over the selected value.
  scf::ReduceOp createReduceOpInNestedParallelLoops(
      lmhlo::ReduceWindowOp reduceWindowOp, scf::ParallelOp outputLoop,
      scf::ParallelOp windowLoop, ConversionPatternRewriter* rewriter) const {
    rewriter->setInsertionPointToStart(windowLoop.getBody());
    auto loc = reduceWindowOp.getLoc();

    // Dilations are not lowered; emit a remark rather than failing the match.
    if (reduceWindowOp.getBaseDilations().has_value() ||
        reduceWindowOp.getWindowDilations().has_value()) {
      reduceWindowOp.emitRemark(
          "Lowering to parallel loops does not support `base_dilations` or "
          "`window_dilations` attributes yet. The attributes will be ignored.");
    }

    Value input = reduceWindowOp.getInputs()[0];
    auto inputType = input.getType().cast<MemRefType>();

    // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
    MappedIvs mappedIvs = mapWindowIvsToInput(
        reduceWindowOp, input, outputLoop.getInductionVars(),
        windowLoop.getInductionVars(), rewriter);

    // scf.if with a result: yields the loaded element when in bounds,
    // otherwise the loop's init (neutral) value.
    auto elemOrInit = rewriter->create<scf::IfOp>(
        loc, inputType.getElementType(), mappedIvs.inBounds,
        /*withElseRegion=*/true);

    OpBuilder thenBuilder =
        elemOrInit.getThenBodyBuilder(rewriter->getListener());
    Value elem =
        thenBuilder.create<mlir::memref::LoadOp>(loc, input, mappedIvs.ivs);
    thenBuilder.create<scf::YieldOp>(loc, elem);

    OpBuilder elseBuilder =
        elemOrInit.getElseBodyBuilder(rewriter->getListener());
    elseBuilder.create<scf::YieldOp>(loc, *windowLoop.getInitVals().begin());

    return rewriter->create<scf::ReduceOp>(loc,
                                           *elemOrInit.getResults().begin());
  }
};
461
462 // See the operation semantics in
463 // https://www.tensorflow.org/xla/operation_semantics#selectandscatter
464 //
465 // Pseudocode:
466 // scf.parallel(coordinates O in the output):
467 // output[O] = init
468 // scf.parallel(coordinates S in the source):
469 // selected_ivs = 0
470 // selected_val = 0
471 // initialized_flag = false
472 // scf.for (first dim W_1 in the window)
473 // iter_args (selected_ivs, selected_val, initialized_flag):
474 // ...
475 // scf.for (last dim W_N in the window):
476 // iter_args (selected_ivs, selected_val, initialized_flag):
477 // I = S * stride + W - pad_low
478 // if I within bounds of operand:
479 // if (initialized_flag):
480 // pred = select(selected_value, operand(I))):
481 // if (pred)
482 // selected_value = operand(I)
483 // selected_index = I
484 // else
485 // selected_value = operand(I)
486 // selected_index = I
487 // initialized_flag = true
488 // output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter
    : public OpConversionPattern<lmhlo::SelectAndScatterOp> {
 public:
  using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      lmhlo::SelectAndScatterOp sAndSOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = sAndSOp.getLoc();
    // Fill the output with the init value before scattering into it.
    initializeOutput(sAndSOp, &rewriter);
    scf::ParallelOp loopOverSrc =
        makeLoopOverShape(loc, sAndSOp.getSource(), &rewriter);
    rewriter.setInsertionPointToStart(loopOverSrc.getBody());

    // Compute indices of the selected element in the window.
    auto selectedIvs = selectIvs(sAndSOp, loopOverSrc, &rewriter);

    // Load `source[selected_ivs]`.
    auto srcElem = rewriter.create<memref::LoadOp>(
        loc, sAndSOp.getSource(), loopOverSrc.getInductionVars());

    // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
    // Atomic RMW because the source loop is parallel and multiple source
    // elements may select the same output location.
    auto rmw = rewriter.create<memref::GenericAtomicRMWOp>(
        loc, sAndSOp.getOut(), selectedIvs);
    OpBuilder rmwBuilder = OpBuilder::atBlockEnd(rmw.getBody());
    auto accResult =
        applySingleResultLhloCode(loc, {srcElem, rmw.getCurrentValue()},
                                  &sAndSOp.getScatter().front(), &rmwBuilder);
    rmwBuilder.create<memref::AtomicYieldOp>(loc, accResult);

    rewriter.replaceOp(sAndSOp, llvm::None);
    return success();
  }

 private:
  // Writes the init value to every element of the output via an scf.parallel
  // loop over the output shape.
  void initializeOutput(lmhlo::SelectAndScatterOp sAndSOp, OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();
    Value initValue = b->create<memref::LoadOp>(loc, sAndSOp.getInitValue());

    scf::ParallelOp loopOverOutput =
        makeLoopOverShape(loc, sAndSOp.getOut(), b);
    OpBuilder::InsertionGuard guard(*b);
    b->setInsertionPointToStart(loopOverOutput.getBody());
    b->create<memref::StoreOp>(loc, initValue, sAndSOp.getOut(),
                               loopOverOutput.getInductionVars());
  }

  // State of the sequential window scan: results of the outermost window loop
  // that hold the selected ivs, the window loop induction variables, and the
  // innermost scf.for (where the select logic is emitted).
  struct WindowLoops {
    SmallVector<Value, 2> selectedIvs;
    SmallVector<Value, 2> windowIvs;
    scf::ForOp innerLoop;
  };

  // Builds one sequential scf.for per window dimension, threading
  // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized] through
  // the nest. On return, `b` is positioned after the outermost window loop.
  WindowLoops insertWindowLoops(lmhlo::SelectAndScatterOp sAndSOp,
                                scf::ParallelOp loopOverSrc,
                                OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();
    Value zero = b->create<arith::ConstantIndexOp>(loc, 0);
    Value one = b->create<arith::ConstantIndexOp>(loc, 1);

    auto elementType =
        sAndSOp.getOut().getType().cast<MemRefType>().getElementType();
    auto rank = loopOverSrc.getNumLoops();

    // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
    // NOTE(review): the placeholder selected_value is built with getFloatAttr,
    // which presumably assumes a float element type — confirm for int inputs.
    SmallVector<Value, 4> iterArgs(rank, zero);
    iterArgs.push_back(b->create<mlir::arith::ConstantOp>(
        loc, elementType, b->getFloatAttr(elementType, 0)));
    iterArgs.push_back(b->create<mlir::arith::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));

    // Create a nested loop that traverses the window.
    // NOTE(review): `ip` is only assigned on the first (outermost) iteration;
    // with an empty window_dimensions it would stay default-constructed.
    OpBuilder::InsertPoint ip;
    WindowLoops result;
    for (const auto& windowDim :
         sAndSOp.getWindowDimensions()->getValues<APInt>()) {
      Value upper =
          b->create<arith::ConstantIndexOp>(loc, windowDim.getSExtValue());
      result.innerLoop = b->create<scf::ForOp>(loc, zero, upper, one, iterArgs);
      if (b->getInsertionBlock() == loopOverSrc.getBody()) {
        // Outermost window loop: remember the post-loop insertion point and
        // expose the first `rank` results as the selected indices.
        ip = b->saveInsertionPoint();
        result.selectedIvs = result.innerLoop.getResults().take_front(rank);
      } else {
        // Inner levels: forward the nested loop's results to the parent loop.
        b->create<scf::YieldOp>(loc, result.innerLoop.getResults());
      }
      b->setInsertionPointToStart(result.innerLoop.getBody());
      iterArgs = ValueRange{result.innerLoop.getRegionIterArgs()};
      result.windowIvs.push_back(result.innerLoop.getInductionVar());
    }
    b->restoreInsertionPoint(ip);
    return result;
  }

  // Adapter to store iteration arguments of sequential loops that perform
  // select in a window.
  class IterArgs {
   public:
    explicit IterArgs(ValueRange ivsValFlag) : ivsValFlag(ivsValFlag) {}
    IterArgs(ValueRange ivs, Value value, Value flag) {
      ivsValFlag = ivs;
      ivsValFlag.push_back(value);
      ivsValFlag.push_back(flag);
    }

    ArrayRef<Value> toVector() const { return ivsValFlag; }

    // Indices of the currently selected value.
    ArrayRef<Value> ivs() const { return toVector().drop_back(2); }
    // Currently selected value w.r.t. select() function.
    Value value() const { return ivsValFlag.end()[-2]; }
    // i1 flag if value() and ivs() were initialized.
    Value isInit() const { return ivsValFlag.back(); }

   private:
    // Vector that stores iv_1, ..., iv_N, value, init.
    SmallVector<Value, 4> ivsValFlag;
  };

  // Emits the window scan that, for the source element at `loopOverSrc`'s
  // ivs, finds the operand element chosen by the op's `select` region.
  // Returns the ivs of the selected element (results of the window nest).
  SmallVector<Value, 2> selectIvs(lmhlo::SelectAndScatterOp sAndSOp,
                                  scf::ParallelOp loopOverSrc,
                                  OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();

    WindowLoops windowLoops = insertWindowLoops(sAndSOp, loopOverSrc, b);
    auto innerLoopB = OpBuilder::atBlockEnd(windowLoops.innerLoop.getBody());

    // Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
    MappedIvs mappedIvs = mapWindowIvsToInput(
        sAndSOp, sAndSOp.getOperand(), loopOverSrc.getInductionVars(),
        windowLoops.windowIvs, &innerLoopB);

    IterArgs ivsValFlag(windowLoops.innerLoop.getRegionIterArgs());

    auto ifInBounds = innerLoopB.create<scf::IfOp>(
        loc, windowLoops.innerLoop.getResultTypes(), mappedIvs.inBounds,
        /*withElseRegion=*/true);

    // Case when we are inside boundaries of 'arg' and not in the pad area.
    {
      OpBuilder inBoundsThenB = ifInBounds.getThenBodyBuilder(b->getListener());
      auto selectOrInitResults = selectOrInitialize(
          sAndSOp, mappedIvs.ivs, &ivsValFlag, &inBoundsThenB);
      inBoundsThenB.create<scf::YieldOp>(loc, selectOrInitResults);
    }

    // Case when we are in the pad: forward the iter_args unchanged.
    {
      OpBuilder inBoundsElseB = ifInBounds.getElseBodyBuilder(b->getListener());
      inBoundsElseB.create<scf::YieldOp>(loc, ivsValFlag.toVector());
    }

    innerLoopB.create<scf::YieldOp>(loc, ifInBounds.getResults());
    return windowLoops.selectedIvs;
  }

  // Emits the per-window-element selection step: if the iter_args are already
  // initialized, run the `select` region against the current element and keep
  // the winner; otherwise take this first in-bounds element unconditionally.
  SmallVector<Value, 4> selectOrInitialize(lmhlo::SelectAndScatterOp sAndSOp,
                                           ArrayRef<Value> operandIvs,
                                           IterArgs* ivsValFlag,
                                           OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();
    Value trueI1 = b->create<mlir::arith::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));

    const TypeRange iterArgTypes{ValueRange{ivsValFlag->toVector()}};
    Value operandElem =
        b->create<memref::LoadOp>(loc, sAndSOp.getOperand(), operandIvs);
    auto ifInit = b->create<scf::IfOp>(loc, iterArgTypes, ivsValFlag->isInit(),
                                       /*withElseRegion=*/true);
    // Init == true, i.e. iter args are already initialized with a selected
    // element in boundaries of the operand. Select function has to be computed
    // here.
    {
      OpBuilder ifInitThenB = ifInit.getThenBodyBuilder(b->getListener());

      auto& lhloSelect = sAndSOp.getSelect().front();
      Value pred = applySingleResultLhloCode(
          loc, {operandElem, ivsValFlag->value()}, &lhloSelect, &ifInitThenB);

      auto ifPred = ifInitThenB.create<scf::IfOp>(loc, iterArgTypes, pred,
                                                  /*withElseRegion=*/true);

      // Pred == true, therefore pack newly selected ivs, val and init flag back
      // to iter_args and return.
      {
        OpBuilder ifPredThenB = ifPred.getThenBodyBuilder(b->getListener());
        ifPredThenB.create<scf::YieldOp>(
            loc, IterArgs{operandIvs, operandElem, trueI1}.toVector());
      }

      // Pred == false, therefore return old iter_args.
      {
        OpBuilder ifPredElseB = ifPred.getElseBodyBuilder(b->getListener());
        ifPredElseB.create<scf::YieldOp>(loc, ivsValFlag->toVector());
      }

      ifInitThenB.create<scf::YieldOp>(loc, ifPred.getResults());
    }
    // Init == false, i.e. only pad was visited before and this is the first
    // element in the boundaries of the operand.
    {
      OpBuilder ifInitElseB = ifInit.getElseBodyBuilder(b->getListener());

      ifInitElseB.create<scf::YieldOp>(
          loc, IterArgs{operandIvs, operandElem, trueI1}.toVector());
    }
    return ifInit.getResults();
  }
};
696
697 struct LhloLegalizeToParallelLoopsPass
698 : public LhloLegalizeToParallelLoopsPassBase<
699 LhloLegalizeToParallelLoopsPass> {
getDependentDialectsmlir::lmhlo::__anon91b1dc7e0111::LhloLegalizeToParallelLoopsPass700 void getDependentDialects(DialectRegistry& registry) const override {
701 registry.insert<arith::ArithmeticDialect, func::FuncDialect,
702 memref::MemRefDialect, scf::SCFDialect>();
703 }
704
runOnOperationmlir::lmhlo::__anon91b1dc7e0111::LhloLegalizeToParallelLoopsPass705 void runOnOperation() override {
706 auto func = getOperation();
707
708 RewritePatternSet patterns(&getContext());
709 // clang-format off
710 patterns.add<
711 ReduceOpConverter,
712 ReduceWindowOpConverter,
713 SelectAndScatterOpConverter
714 >(func.getContext());
715 // clang-format on
716
717 ConversionTarget target(getContext());
718 target.addLegalDialect<arith::ArithmeticDialect, linalg::LinalgDialect,
719 memref::MemRefDialect, func::FuncDialect,
720 scf::SCFDialect, LmhloDialect>();
721 target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
722 lmhlo::SelectAndScatterOp>();
723
724 if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
725 signalPassFailure();
726 }
727 }
728 };
729 } // namespace
730
731 std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeLhloToParallelLoopsPass()732 createLegalizeLhloToParallelLoopsPass() {
733 return std::make_unique<LhloLegalizeToParallelLoopsPass>();
734 }
735
736 } // namespace lmhlo
737 } // namespace mlir
738