1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
20 #include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
21 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/Linalg/IR/Linalg.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 
30 namespace mlir {
31 namespace lmhlo {
32 namespace {
33 
34 // Clones and adapts the code in `lhlo_block` that works on buffers and has a
35 // single output buffer to make it compatible with `operands` that have element
36 // types of the respective buffers. Returns the computed value.
37 //
38 // Example. For `operands` with (f32, i32) types and a block with LHLO ops and
39 // with signature:
40 //   ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
41 //     <LHLO_ops>
42 //
43 // inserts necessary alloc and store ops to compute and return result that has
44 // `i1` type.
applySingleResultLhloCode(Location loc,ValueRange operands,Block * lhloBlock,OpBuilder * b)45 Value applySingleResultLhloCode(Location loc, ValueRange operands,
46                                 Block* lhloBlock, OpBuilder* b) {
47   SmallVector<Value, 2> argBufs;
48   for (auto argType : lhloBlock->getArgumentTypes()) {
49     argBufs.push_back(
50         b->create<memref::AllocOp>(loc, argType.cast<MemRefType>()));
51   }
52   for (const auto& operand : llvm::enumerate(operands)) {
53     b->create<memref::StoreOp>(loc, operand.value(), argBufs[operand.index()]);
54   }
55   // Clone the ops from `lhlo_block`.
56   BlockAndValueMapping mapping;
57   mapping.map(lhloBlock->getArguments(), argBufs);
58   for (auto& nested : lhloBlock->without_terminator()) {
59     auto* clone = b->clone(nested, mapping);
60     mapping.map(nested.getResults(), clone->getResults());
61   }
62   return b->create<memref::LoadOp>(loc, argBufs.back());
63 }
64 
65 // Converts a block with LHLO ops and with signature:
66 //   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
67 // into a reduction operator of scf.reduce by doing buffer allocation for
68 // scalar arguments and the result of `scf.reduce` to make it compatible with
69 // LHLO ops.
convertToReductionOperator(Location loc,scf::ReduceOp reduceOp,Block * lhloBlock,OpBuilder * b)70 void convertToReductionOperator(Location loc, scf::ReduceOp reduceOp,
71                                 Block* lhloBlock, OpBuilder* b) {
72   Block& loopReduceOpBody = reduceOp.getReductionOperator().front();
73   OpBuilder::InsertionGuard guard(*b);
74   b->setInsertionPointToStart(&loopReduceOpBody);
75   b->create<scf::ReduceReturnOp>(
76       loc, applySingleResultLhloCode(loc, loopReduceOpBody.getArguments(),
77                                      lhloBlock, b));
78 }
79 
80 // Returns result of arith::ConstantOp if `dim` is static, otherwise uses DimOp
81 // to extract dimension at runtime.
getStaticOrDynamicDim(mlir::Location loc,Value shapedValue,size_t dimIndex,int64_t dim,OpBuilder * b)82 Value getStaticOrDynamicDim(mlir::Location loc, Value shapedValue,
83                             size_t dimIndex, int64_t dim, OpBuilder* b) {
84   return dim == ShapedType::kDynamicSize
85              ? b->create<memref::DimOp>(loc, shapedValue, dimIndex).getResult()
86              : b->create<arith::ConstantIndexOp>(loc, dim);
87 }
88 
// Result of mapping window-relative induction variables to indices into the
// operand buffer (see mapWindowIvsToInput below).
struct MappedIvs {
  // False if the mapped indices are in the padding area, true otherwise.
  Value inBounds;
  // Mapped indices.
  SmallVector<Value, 2> ivs;
};
95 
// Maps the output-loop induction variables `ivs` and the window-loop
// induction variables `windowIvs` to indices into `operand`:
//   index[i] = ivs[i] * stride[i] + windowIvs[i] - pad_low[i]
// and computes an i1 flag that is true iff every mapped index lies inside
// the operand's bounds (i.e. outside the padding area). `OpTy` must provide
// getWindowStrides(), getPadding() and getLoc().
template <typename OpTy>
MappedIvs mapWindowIvsToInput(OpTy op, Value operand, ValueRange ivs,
                              ValueRange windowIvs, OpBuilder* b) {
  MappedIvs mappedIvs;

  // NOTE(review): emitOpError only reports a diagnostic; execution still
  // reaches the .value() below. Presumably callers guarantee these
  // attributes are present -- TODO confirm.
  if (!op.getWindowStrides().has_value()) {
    op.emitOpError("No window strides specified.");
  }
  auto windowStrides = op.getWindowStrides().value();

  if (!op.getPadding().has_value()) {
    op.emitOpError("No padding specified.");
  }
  auto padding = op.getPadding().value();

  auto loc = op.getLoc();
  auto operandShape = operand.getType().template cast<MemRefType>().getShape();

  // `in_bounds` is false when the mapped indices are in the padding area.
  // Start with `true` and AND in one bounds check per dimension.
  mappedIvs.inBounds = b->create<mlir::arith::ConstantOp>(
      loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
  for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
    auto stride = windowStrides.template getValues<llvm::APInt>()[i];
    // Padding is an Nx2 attribute; column 0 is the low-side padding.
    auto padLow = padding.template getValues<llvm::APInt>()[{i, 0}];

    Value strideVal =
        b->create<arith::ConstantIndexOp>(loc, stride.getSExtValue());
    Value padLowVal =
        b->create<arith::ConstantIndexOp>(loc, padLow.getSExtValue());

    Value center = b->create<arith::MulIOp>(loc, ivs[i], strideVal);
    Value offset = b->create<arith::SubIOp>(loc, windowIvs[i], padLowVal);
    Value index = b->create<arith::AddIOp>(loc, center, offset);
    Value upperBound =
        getStaticOrDynamicDim(loc, operand, i, operandShape[i], b);
    // We must check whether 0 <= index_i < shape_i, as otherwise we are in
    // the pad and then we have to use the neutral element for reduction.
    // Equivalently, it can be computed as the unsigned comparison index_i <
    // shape_i, since a negative value wraps to a large positive value.
    mappedIvs.inBounds = b->create<mlir::arith::AndIOp>(
        loc, mappedIvs.inBounds,
        b->create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, index,
                                 upperBound));
    mappedIvs.ivs.push_back(index);
  }
  return mappedIvs;
}
143 
144 // Returns scf::Parallel over a shaped value with static or dynamic shape.
makeLoopOverShape(Location loc,Value shapedValue,OpBuilder * b)145 scf::ParallelOp makeLoopOverShape(Location loc, Value shapedValue,
146                                   OpBuilder* b) {
147   Value zero = b->create<arith::ConstantIndexOp>(loc, 0);
148   Value one = b->create<arith::ConstantIndexOp>(loc, 1);
149 
150   ArrayRef<int64_t> shape = shapedValue.getType().cast<ShapedType>().getShape();
151   SmallVector<Value, 2> lower, upper, step;
152   for (const auto& dim : llvm::enumerate(shape)) {
153     upper.push_back(
154         getStaticOrDynamicDim(loc, shapedValue, dim.index(), dim.value(), b));
155     lower.push_back(zero);
156     step.push_back(one);
157   }
158   return b->create<scf::ParallelOp>(loc, lower, upper, step);
159 }
160 
161 // Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
164 // contains the reduction operator.
165 //
166 // Example:
167 //
168 //  "lmhlo.reduce"(%buffer, %init_buf, %result) ({
169 //    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
170 //      <LHLO ops>
171 //    } ) {dimensions = dense<[1]> : tensor<1xi64>}
172 //      : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
173 //
174 //  is roughly converted into:
175 //
176 //  %init = load %init_buf[] : memref<f32>
177 //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
178 //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
179 //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
180 //      scf.reduce(%elem_to_reduce)  {
181 //        ^bb0(%elem: f32, %acc: f32):
182 //          elem_buf = alloc() : memref<f32>
183 //          store %elem, elem_buf[] : memref<f32>
184 //          acc_buf = alloc() : memref<f32>
185 //          store %acc, acc_buf[] : memref<f32>
186 //          <LHLO_ops>
187 //          %acc_result = load acc_buf[] : memref<f32>
188 //          scf.reduce.return %acc_result : f32
189 //      } : f32
190 //      scf.yield
191 //    } : f32
192 //    scf.yield
193 //  }
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;

  // Lowers a single-result lmhlo.reduce to nested scf.parallel loops with an
  // scf.reduce whose body is cloned from the lmhlo reduction region.
  LogicalResult matchAndRewrite(
      lmhlo::ReduceOp reduceOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
    if (reduceOp.getOut().size() != 1) return failure();

    scf::ReduceOp scfReduceOp =
        createReduceOpInNestedParallelLoops(reduceOp, &rewriter);
    convertToReductionOperator(reduceOp.getLoc(), scfReduceOp,
                               &reduceOp.getBody().front(), &rewriter);
    // lmhlo.reduce writes through its output buffer and has no results.
    rewriter.replaceOp(reduceOp, llvm::None);
    return success();
  }

 private:
  // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
  // refers to the parallel dimensions of `reduce_op` if any and the inner
  // ParallelOp refers to the reduction dimensions. The scf.reduce op is
  // returned.
  //
  // If the reduction argument is a memref<100x10x5xf32> and the
  // reduction is performed along dimension 1 then this method will generate
  //
  //  %init = load %init_buf[] : memref<f32>
  //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
  //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
  //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
  //      scf.reduce(%elem_to_reduce)  {
  //        <THE BLOCK PTR TO BE RETURNED>
  //      } : f32
  //      scf.yield
  //    } : f32
  //    scf.yield
  //  }
  scf::ReduceOp createReduceOpInNestedParallelLoops(
      lmhlo::ReduceOp reduceOp, ConversionPatternRewriter* rewriter) const {
    auto loc = reduceOp.getLoc();
    // Set of operand dimension indices that are reduced over.
    DenseSet<int> reducingDims;
    for (const auto& rdim : reduceOp.getDimensions().getValues<APInt>()) {
      reducingDims.insert(rdim.getSExtValue());
    }

    Value operand = reduceOp.getInputs().front();
    Value out = reduceOp.getOut().front();
    // Partition each operand dimension's bounds into either the reduction
    // loop or the parallel loop, preserving the original dimension order
    // within each group.
    SmallVector<Value, 2> parallelLower, parallelUpper, parallelStep;
    SmallVector<Value, 2> reduceLower, reduceUpper, reduceStep;
    auto operandShape = operand.getType().cast<MemRefType>().getShape();
    for (const auto& dim : llvm::enumerate(operandShape)) {
      const bool isReducingDim = reducingDims.count(dim.index());

      Value ub = getStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
                                       rewriter);
      Value lb = rewriter->create<arith::ConstantIndexOp>(loc, 0);
      Value step = rewriter->create<arith::ConstantIndexOp>(loc, 1);
      (isReducingDim ? reduceLower : parallelLower).push_back(lb);
      (isReducingDim ? reduceUpper : parallelUpper).push_back(ub);
      (isReducingDim ? reduceStep : parallelStep).push_back(step);
    }
    // Load initial value from memref<element_type>.
    SmallVector<Value, 1> initValue = {rewriter->create<memref::LoadOp>(
        loc, *reduceOp.getInitValues().begin())};
    // Outer ParallelOp is not needed if it is a reduction across all dims.
    scf::ParallelOp outer;
    if (!parallelLower.empty()) {
      outer = rewriter->create<scf::ParallelOp>(loc, parallelLower,
                                                parallelUpper, parallelStep);
      rewriter->setInsertionPointToStart(outer.getBody());
    }
    // The inner loop carries the accumulator; its single result is the
    // reduced value for the current output coordinates.
    scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
        loc, reduceLower, reduceUpper, reduceStep, ValueRange(initValue));
    Value reductionResult = *inner.getResults().begin();

    // Indices at which the reduced value is stored: the outer loop's ivs, or
    // a single zero index when reducing across all dimensions.
    SmallVector<Value, 1> outIndices;
    if (outer != nullptr) {
      outIndices.reserve(outer.getNumLoops());
      for (Value iv : outer.getInductionVars()) {
        outIndices.push_back(iv);
      }
    } else {
      outIndices.push_back(rewriter->create<arith::ConstantIndexOp>(loc, 0));
    }

    rewriter->create<memref::StoreOp>(loc, reductionResult, out, outIndices);

    // Load the element to reduce.
    // Interleave the outer (parallel) and inner (reduction) ivs to rebuild
    // the operand indices in the original dimension order.
    SmallVector<Value, 2> indices;
    indices.reserve(operandShape.size());

    if (outer) {
      auto innerIvsIt = inner.getInductionVars().begin();
      auto outerIvsIt = outer.getInductionVars().begin();
      for (unsigned i = 0, e = operandShape.size(); i < e; ++i) {
        indices.push_back(reducingDims.count(i) ? *innerIvsIt++
                                                : *outerIvsIt++);
      }
    } else {
      indices = inner.getInductionVars();
    }

    // Emit the load and the scf.reduce inside the inner loop; the reduce
    // body is filled in later by convertToReductionOperator.
    rewriter->setInsertionPointToStart(inner.getBody());
    Value elem = rewriter->create<mlir::memref::LoadOp>(
        loc, reduceOp.getInputs().front(), indices);
    return rewriter->create<scf::ReduceOp>(loc, elem);
  }
};
303 
304 // Pseudocode:
305 // for each index O in output
306 //   accumulator = neutral_value
307 //   in_bounds = true
308 //   for each index W in window
309 //     for each dimension i from 0 to rank - 1
310 //       index = O[i] * stride[i] + W[i] - pad_low[i]
311 //       in_bounds = inbounds && (index `ult` shape[i])
312 //       I[i] = index
313 //     if (in_bounds)
314 //       value = input[I]
315 //     else
316 //       value = neutral_value
317 //     accumulator = reduction_operator(accumulator, value)
318 //   output[O] = accumulator
319 //
320 // Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
321 // scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the
// output buffer. The inner `ParallelOp` refers to the reduction loops that
// traverse
324 // reduction windows and `ReduceOp` contains the reduction operator.
325 //
326 // Example:
327 //
328 // func @reduce_window(%arg: memref<112x112xf32>,
329 //              %init: memref<f32>,
330 //              %result: memref<56x56xf32>) {
331 //   "lmhlo.reduce_window"(%arg, %init, %result) ({
332 //     ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
333 //       "lmhlo.maximum"(%lhs, %rhs, %res)
334 //         : (memref<f32>, memref<f32>, memref<f32>) -> ()
335 //       "lmhlo.terminator"() : () -> ()
336 //     }) {
337 //       padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
338 //       window_dimensions = dense<[3, 3]> : tensor<2xi64>,
339 //       window_strides = dense<[2, 2]> : tensor<2xi64>
340 //     } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
341 //   return
342 // }
343 //
344 // is roughly converted into:
345 //
346 //    %neutral_elem = load %init_buf[] : memref<f32>
347 //    scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
348 //      %result = scf.parallel (%iw, %jw) = (%c0, %c0)
349 //                  to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 {
350 //        %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
351 //        %elem = load %operand[%computed_i, %computed_j]
352 //        %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
353 //        scf.reduce(%elem_to_reduce)  : f32 {
354 //          ^bb0(%arg7: f32, %arg8: f32):
355 //            <LHLO ops>
356 //        }
357 //        scf.yield
358 //      }
359 //      store %result, %output_buffer[%i, %j] : memref<56x56xf32>
360 //      scf.yield
361 //    }
362 //    return
363 //  }
class ReduceWindowOpConverter
    : public OpConversionPattern<lmhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;

  // Lowers a single-result lmhlo.reduce_window to an output loop over the
  // result buffer and a nested window loop that carries the accumulator.
  LogicalResult matchAndRewrite(
      lmhlo::ReduceWindowOp reduceWindowOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/183977252) : Handle variadic ReduceOp/ReduceWindowOp
    if (reduceWindowOp.getOut().size() != 1) return failure();

    scf::ParallelOp outputLoop, windowLoop;
    std::tie(outputLoop, windowLoop) =
        createParallelLoopsToTraverseOutputAndWindow(reduceWindowOp, &rewriter);

    scf::ReduceOp reduceOp = createReduceOpInNestedParallelLoops(
        reduceWindowOp, outputLoop, windowLoop, &rewriter);

    convertToReductionOperator(reduceWindowOp.getLoc(), reduceOp,
                               &reduceWindowOp.getBody().front(), &rewriter);
    // lmhlo.reduce_window writes through its output buffer; no results.
    rewriter.replaceOp(reduceWindowOp, llvm::None);
    return success();
  }

 private:
  // Creates the outer parallel loop over the output shape and, nested inside
  // it, the parallel loop over the window dimensions (initialized with the
  // init value as the accumulator). Stores the window loop's result into the
  // output buffer and returns both loops.
  std::pair<scf::ParallelOp, scf::ParallelOp>
  createParallelLoopsToTraverseOutputAndWindow(
      lmhlo::ReduceWindowOp reduceWindowOp,
      ConversionPatternRewriter* rewriter) const {
    auto loc = reduceWindowOp.getLoc();
    Value initValue = rewriter->create<memref::LoadOp>(
        loc, reduceWindowOp.getInitValues()[0]);

    Value zero = rewriter->create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter->create<arith::ConstantIndexOp>(loc, 1);

    // Create an outer parallel loop that spans the output of ReduceWindowOp.
    Value output = reduceWindowOp.getOut()[0];
    auto outputLoop = makeLoopOverShape(loc, output, rewriter);

    // Create a nested loop that traverses the window.
    SmallVector<Value, 2> windowLower, windowUpper, windowStep;
    rewriter->setInsertionPointToStart(outputLoop.getBody());
    for (const auto& windowDim : reduceWindowOp.getWindowDimensions()) {
      windowStep.push_back(one);
      windowLower.push_back(zero);
      windowUpper.push_back(rewriter->create<arith::ConstantIndexOp>(
          loc, windowDim.getSExtValue()));
    }
    auto windowLoop = rewriter->create<scf::ParallelOp>(
        loc, windowLower, windowUpper, windowStep, ValueRange(initValue));

    // The window loop's single result is the reduced value for the current
    // output coordinate.
    Value reductionResult = *windowLoop.getResults().begin();
    auto outputIvs = outputLoop.getInductionVars();
    rewriter->create<memref::StoreOp>(loc, reductionResult, output, outputIvs);
    return std::make_pair(outputLoop, windowLoop);
  }

  // Inside the window loop: maps the (output, window) ivs to operand
  // indices, selects either the loaded element (in bounds) or the neutral
  // init value (in the pad) via scf.if, and feeds the selected value to an
  // scf.reduce whose body is filled in later.
  scf::ReduceOp createReduceOpInNestedParallelLoops(
      lmhlo::ReduceWindowOp reduceWindowOp, scf::ParallelOp outputLoop,
      scf::ParallelOp windowLoop, ConversionPatternRewriter* rewriter) const {
    rewriter->setInsertionPointToStart(windowLoop.getBody());
    auto loc = reduceWindowOp.getLoc();

    // Dilations are silently dropped by this lowering; warn the user.
    if (reduceWindowOp.getBaseDilations().has_value() ||
        reduceWindowOp.getWindowDilations().has_value()) {
      reduceWindowOp.emitRemark(
          "Lowering to parallel loops does not support `base_dilations` or "
          "`window_dilations` attributes yet. The attributes will be ignored.");
    }

    Value input = reduceWindowOp.getInputs()[0];
    auto inputType = input.getType().cast<MemRefType>();

    // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
    MappedIvs mappedIvs = mapWindowIvsToInput(
        reduceWindowOp, input, outputLoop.getInductionVars(),
        windowLoop.getInductionVars(), rewriter);

    // scf.if yielding either the loaded element or the neutral init value.
    auto elemOrInit = rewriter->create<scf::IfOp>(
        loc, inputType.getElementType(), mappedIvs.inBounds,
        /*withElseRegion=*/true);

    OpBuilder thenBuilder =
        elemOrInit.getThenBodyBuilder(rewriter->getListener());
    Value elem =
        thenBuilder.create<mlir::memref::LoadOp>(loc, input, mappedIvs.ivs);
    thenBuilder.create<scf::YieldOp>(loc, elem);

    OpBuilder elseBuilder =
        elemOrInit.getElseBodyBuilder(rewriter->getListener());
    elseBuilder.create<scf::YieldOp>(loc, *windowLoop.getInitVals().begin());

    return rewriter->create<scf::ReduceOp>(loc,
                                           *elemOrInit.getResults().begin());
  }
};
461 
462 // See the operation semantics in
463 // https://www.tensorflow.org/xla/operation_semantics#selectandscatter
464 //
465 // Pseudocode:
466 //  scf.parallel(coordinates O in the output):
467 //    output[O] = init
468 //  scf.parallel(coordinates S in the source):
469 //    selected_ivs = 0
470 //    selected_val = 0
471 //    initialized_flag = false
472 //    scf.for (first dim W_1 in the window)
473 //         iter_args (selected_ivs, selected_val, initialized_flag):
474 //    ...
475 //      scf.for (last dim W_N in the window):
476 //           iter_args (selected_ivs, selected_val, initialized_flag):
477 //        I = S * stride + W - pad_low
478 //        if I within bounds of operand:
479 //          if (initialized_flag):
480 //            pred = select(selected_value, operand(I))):
481 //            if (pred)
482 //              selected_value = operand(I)
483 //              selected_index = I
484 //          else
485 //              selected_value = operand(I)
486 //              selected_index = I
487 //              initialized_flag = true
488 //    output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter
    : public OpConversionPattern<lmhlo::SelectAndScatterOp> {
 public:
  using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;

  // Lowers lmhlo.select_and_scatter following the pseudocode above:
  // initialize the output with the init value, then for every source element
  // select a position within its window and atomically scatter into the
  // output at that position.
  LogicalResult matchAndRewrite(
      lmhlo::SelectAndScatterOp sAndSOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = sAndSOp.getLoc();
    initializeOutput(sAndSOp, &rewriter);
    scf::ParallelOp loopOverSrc =
        makeLoopOverShape(loc, sAndSOp.getSource(), &rewriter);
    rewriter.setInsertionPointToStart(loopOverSrc.getBody());

    // Compute indices of the selected element in the window.
    auto selectedIvs = selectIvs(sAndSOp, loopOverSrc, &rewriter);

    // Load `source[selected_ivs]`.
    auto srcElem = rewriter.create<memref::LoadOp>(
        loc, sAndSOp.getSource(), loopOverSrc.getInductionVars());

    // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
    // GenericAtomicRMWOp is used because distinct source elements may select
    // the same output location, so the scatter must be applied atomically.
    auto rmw = rewriter.create<memref::GenericAtomicRMWOp>(
        loc, sAndSOp.getOut(), selectedIvs);
    OpBuilder rmwBuilder = OpBuilder::atBlockEnd(rmw.getBody());
    auto accResult =
        applySingleResultLhloCode(loc, {srcElem, rmw.getCurrentValue()},
                                  &sAndSOp.getScatter().front(), &rmwBuilder);
    rmwBuilder.create<memref::AtomicYieldOp>(loc, accResult);

    rewriter.replaceOp(sAndSOp, llvm::None);
    return success();
  }

 private:
  // Fills the entire output buffer with the init value.
  void initializeOutput(lmhlo::SelectAndScatterOp sAndSOp, OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();
    Value initValue = b->create<memref::LoadOp>(loc, sAndSOp.getInitValue());

    scf::ParallelOp loopOverOutput =
        makeLoopOverShape(loc, sAndSOp.getOut(), b);
    OpBuilder::InsertionGuard guard(*b);
    b->setInsertionPointToStart(loopOverOutput.getBody());
    b->create<memref::StoreOp>(loc, initValue, sAndSOp.getOut(),
                               loopOverOutput.getInductionVars());
  }

  // Loops and induction variables of the sequential window traversal.
  struct WindowLoops {
    // Results of the outermost window loop: the ivs of the selected element.
    SmallVector<Value, 2> selectedIvs;
    // One induction variable per window dimension, outermost first.
    SmallVector<Value, 2> windowIvs;
    // The innermost scf.for; the selection logic is emitted into its body.
    scf::ForOp innerLoop;
  };

  // Creates a nest of scf.for loops over the window dimensions inside
  // `loop_over_src`, each carrying
  // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized].
  // On return, the builder's insertion point is right after the outermost
  // window loop (inside `loop_over_src`).
  WindowLoops insertWindowLoops(lmhlo::SelectAndScatterOp sAndSOp,
                                scf::ParallelOp loopOverSrc,
                                OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();
    Value zero = b->create<arith::ConstantIndexOp>(loc, 0);
    Value one = b->create<arith::ConstantIndexOp>(loc, 1);

    auto elementType =
        sAndSOp.getOut().getType().cast<MemRefType>().getElementType();
    auto rank = loopOverSrc.getNumLoops();

    // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
    // NOTE(review): the selected_value placeholder is built with
    // getFloatAttr, which presumably assumes a floating-point element type
    // -- TODO confirm.
    SmallVector<Value, 4> iterArgs(rank, zero);
    iterArgs.push_back(b->create<mlir::arith::ConstantOp>(
        loc, elementType, b->getFloatAttr(elementType, 0)));
    iterArgs.push_back(b->create<mlir::arith::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));

    // Create a nested loop that traverses the window.
    OpBuilder::InsertPoint ip;
    WindowLoops result;
    for (const auto& windowDim :
         sAndSOp.getWindowDimensions()->getValues<APInt>()) {
      Value upper =
          b->create<arith::ConstantIndexOp>(loc, windowDim.getSExtValue());
      result.innerLoop = b->create<scf::ForOp>(loc, zero, upper, one, iterArgs);
      if (b->getInsertionBlock() == loopOverSrc.getBody()) {
        // Outermost window loop: remember the insertion point after it and
        // expose its first `rank` results as the selected indices.
        ip = b->saveInsertionPoint();
        result.selectedIvs = result.innerLoop.getResults().take_front(rank);
      } else {
        // Inner loop: forward the nested loop's results to the parent loop.
        b->create<scf::YieldOp>(loc, result.innerLoop.getResults());
      }
      b->setInsertionPointToStart(result.innerLoop.getBody());
      iterArgs = ValueRange{result.innerLoop.getRegionIterArgs()};
      result.windowIvs.push_back(result.innerLoop.getInductionVar());
    }
    // NOTE(review): `ip` is only assigned inside the loop above, so this
    // assumes at least one window dimension -- TODO confirm.
    b->restoreInsertionPoint(ip);
    return result;
  }

  // Adapter to store iteration arguments of sequential loops that perform
  // select in a window.
  class IterArgs {
   public:
    explicit IterArgs(ValueRange ivsValFlag) : ivsValFlag(ivsValFlag) {}
    IterArgs(ValueRange ivs, Value value, Value flag) {
      ivsValFlag = ivs;
      ivsValFlag.push_back(value);
      ivsValFlag.push_back(flag);
    }

    // Full packed form [iv_1, ..., iv_N, value, init] for yielding.
    ArrayRef<Value> toVector() const { return ivsValFlag; }

    // Indices of the currently selected value.
    ArrayRef<Value> ivs() const { return toVector().drop_back(2); }
    // Currently selected value w.r.t. select() function.
    Value value() const { return ivsValFlag.end()[-2]; }
    // i1 flag if value() and ivs() were initialized.
    Value isInit() const { return ivsValFlag.back(); }

   private:
    // Vector that stores iv_1, ..., iv_N, value, init.
    SmallVector<Value, 4> ivsValFlag;
  };

  // Emits the window traversal plus the in-bounds/selection logic and
  // returns the ivs of the element selected for the current source element.
  SmallVector<Value, 2> selectIvs(lmhlo::SelectAndScatterOp sAndSOp,
                                  scf::ParallelOp loopOverSrc,
                                  OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();

    WindowLoops windowLoops = insertWindowLoops(sAndSOp, loopOverSrc, b);
    auto innerLoopB = OpBuilder::atBlockEnd(windowLoops.innerLoop.getBody());

    // Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
    MappedIvs mappedIvs = mapWindowIvsToInput(
        sAndSOp, sAndSOp.getOperand(), loopOverSrc.getInductionVars(),
        windowLoops.windowIvs, &innerLoopB);

    IterArgs ivsValFlag(windowLoops.innerLoop.getRegionIterArgs());

    auto ifInBounds = innerLoopB.create<scf::IfOp>(
        loc, windowLoops.innerLoop.getResultTypes(), mappedIvs.inBounds,
        /*withElseRegion=*/true);

    // Case when we are inside boundaries of 'arg' and not in the pad area.
    {
      OpBuilder inBoundsThenB = ifInBounds.getThenBodyBuilder(b->getListener());
      auto selectOrInitResults = selectOrInitialize(
          sAndSOp, mappedIvs.ivs, &ivsValFlag, &inBoundsThenB);
      inBoundsThenB.create<scf::YieldOp>(loc, selectOrInitResults);
    }

    // Case when we are in the pad: keep the current iter_args unchanged.
    {
      OpBuilder inBoundsElseB = ifInBounds.getElseBodyBuilder(b->getListener());
      inBoundsElseB.create<scf::YieldOp>(loc, ivsValFlag.toVector());
    }

    innerLoopB.create<scf::YieldOp>(loc, ifInBounds.getResults());
    return windowLoops.selectedIvs;
  }

  // Emits the select step for one in-bounds candidate: if the accumulator is
  // initialized, run the lmhlo select region to decide whether the candidate
  // replaces the current selection; otherwise take the candidate
  // unconditionally. Returns the updated packed iter_args.
  SmallVector<Value, 4> selectOrInitialize(lmhlo::SelectAndScatterOp sAndSOp,
                                           ArrayRef<Value> operandIvs,
                                           IterArgs* ivsValFlag,
                                           OpBuilder* b) const {
    auto loc = sAndSOp.getLoc();
    Value trueI1 = b->create<mlir::arith::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));

    const TypeRange iterArgTypes{ValueRange{ivsValFlag->toVector()}};
    Value operandElem =
        b->create<memref::LoadOp>(loc, sAndSOp.getOperand(), operandIvs);
    auto ifInit = b->create<scf::IfOp>(loc, iterArgTypes, ivsValFlag->isInit(),
                                       /*withElseRegion=*/true);
    // Init == true, i.e. iter args are already initialized with a selected
    // element in boundaries of the operand. Select function has to be computed
    // here.
    {
      OpBuilder ifInitThenB = ifInit.getThenBodyBuilder(b->getListener());

      auto& lhloSelect = sAndSOp.getSelect().front();
      Value pred = applySingleResultLhloCode(
          loc, {operandElem, ivsValFlag->value()}, &lhloSelect, &ifInitThenB);

      auto ifPred = ifInitThenB.create<scf::IfOp>(loc, iterArgTypes, pred,
                                                  /*withElseRegion=*/true);

      // Pred == true, therefore pack newly selected ivs, val and init flag back
      // to iter_args and return.
      {
        OpBuilder ifPredThenB = ifPred.getThenBodyBuilder(b->getListener());
        ifPredThenB.create<scf::YieldOp>(
            loc, IterArgs{operandIvs, operandElem, trueI1}.toVector());
      }

      // Pred == false, therefore return old iter_args.
      {
        OpBuilder ifPredElseB = ifPred.getElseBodyBuilder(b->getListener());
        ifPredElseB.create<scf::YieldOp>(loc, ivsValFlag->toVector());
      }

      ifInitThenB.create<scf::YieldOp>(loc, ifPred.getResults());
    }
    // Init == false, i.e. only pad was visited before and this is the first
    // element in the boundaries of the operand.
    {
      OpBuilder ifInitElseB = ifInit.getElseBodyBuilder(b->getListener());

      ifInitElseB.create<scf::YieldOp>(
          loc, IterArgs{operandIvs, operandElem, trueI1}.toVector());
    }
    return ifInit.getResults();
  }
};
696 
697 struct LhloLegalizeToParallelLoopsPass
698     : public LhloLegalizeToParallelLoopsPassBase<
699           LhloLegalizeToParallelLoopsPass> {
getDependentDialectsmlir::lmhlo::__anon91b1dc7e0111::LhloLegalizeToParallelLoopsPass700   void getDependentDialects(DialectRegistry& registry) const override {
701     registry.insert<arith::ArithmeticDialect, func::FuncDialect,
702                     memref::MemRefDialect, scf::SCFDialect>();
703   }
704 
runOnOperationmlir::lmhlo::__anon91b1dc7e0111::LhloLegalizeToParallelLoopsPass705   void runOnOperation() override {
706     auto func = getOperation();
707 
708     RewritePatternSet patterns(&getContext());
709     // clang-format off
710     patterns.add<
711         ReduceOpConverter,
712         ReduceWindowOpConverter,
713         SelectAndScatterOpConverter
714       >(func.getContext());
715     // clang-format on
716 
717     ConversionTarget target(getContext());
718     target.addLegalDialect<arith::ArithmeticDialect, linalg::LinalgDialect,
719                            memref::MemRefDialect, func::FuncDialect,
720                            scf::SCFDialect, LmhloDialect>();
721     target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
722                         lmhlo::SelectAndScatterOp>();
723 
724     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
725       signalPassFailure();
726     }
727   }
728 };
729 }  // namespace
730 
// Returns a pass that lowers lmhlo.reduce, lmhlo.reduce_window and
// lmhlo.select_and_scatter on a func.func to scf parallel loops.
std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeLhloToParallelLoopsPass() {
  return std::make_unique<LhloLegalizeToParallelLoopsPass>();
}
735 
736 }  // namespace lmhlo
737 }  // namespace mlir
738