/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {

namespace {

// Broadcasts the 1D value tensor 'value1d' to the shape of 'resultType'. If
// 'shapeValue' is initialized, creates a dynamic broadcast, otherwise creates
// a static broadcast.
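// As a rough sketch (the shapes are only illustrative), broadcasting a
// tensor<64xf32> scale into a static NHWC result with featureDim = 3 emits:
//   "mhlo.broadcast_in_dim"(%scale)
//       {broadcast_dimensions = dense<3> : tensor<1xi64>}
//       : (tensor<64xf32>) -> tensor<8x32x32x64xf32>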
Value broadcastToFeatureDim(Location loc, RankedTensorType resultType,
                            Value value1d, Value shapeValue, int64_t featureDim,
                            PatternRewriter& rewriter) {  // NOLINT
  auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dimsType, {featureDim});
  if (shapeValue) {
    return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
        loc, resultType, value1d, shapeValue, dims);
  }
  assert(resultType.hasStaticShape());
  return rewriter.create<mhlo::BroadcastInDimOp>(loc, resultType, value1d,
                                                 dims);
}

// Get the shape of operand, assuming it is a dynamic shape with static rank.
Value getShapeValue(Location loc, Value operand,
                    PatternRewriter &rewriter) {  // NOLINT
  RankedTensorType resultType = operand.getType().dyn_cast<RankedTensorType>();
  return rewriter.create<mlir::shape::ShapeOfOp>(
      loc,
      RankedTensorType::get({resultType.getRank()}, rewriter.getIndexType()),
      operand);
}

Value materializeEpsilon(Operation *op, FloatAttr epsilonAttr, FloatType fpType,
                         Value broadcastTo, RankedTensorType broadcastToType,
                         PatternRewriter &rewriter) {  // NOLINT
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  if (epsilonAttr.getType() != fpType) {
    // Need to convert.
    bool losesInfo;
    APFloat epsilonFloat = epsilonAttr.getValue();
    auto status = epsilonFloat.convert(
        fpType.getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
    if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
      op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
                           "type: opStatus = "
                        << static_cast<int>(status);
      return nullptr;
    }
    if (losesInfo) {
      op->emitWarning("Conversion of epsilon loses precision");
    }
    epsilonAttr = b.getFloatAttr(fpType, epsilonFloat);
  }

  auto scalarType = RankedTensorType::get({}, fpType);
  auto epsilonTensorAttr =
      DenseElementsAttr::get(scalarType, {epsilonAttr.cast<Attribute>()});
  Value epsilon = b.create<mhlo::ConstantOp>(epsilonTensorAttr);
  auto dimsType = RankedTensorType::get({0}, b.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
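  // An empty 'broadcast_dimensions' attribute maps no operand dimensions to
  // the result, i.e. the scalar epsilon is broadcast across every dimension
  // of 'broadcastToType'.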
  if (broadcastToType.hasStaticShape()) {
    return b.create<mhlo::BroadcastInDimOp>(broadcastToType, epsilon,
                                            /*broadcast_dims=*/dims);
  }
  Value shapeValue = getShapeValue(op->getLoc(), broadcastTo, rewriter);
  return b.createOrFold<mhlo::DynamicBroadcastInDimOp>(broadcastToType, epsilon,
                                                       shapeValue,
                                                       /*broadcast_dims=*/dims);
}

class UnfuseBatchNormInferencePattern
    : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
 public:
  using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bnOp,
                                PatternRewriter& rewriter) const override {
    // Enforce type invariants.
    // Note that we deduce the actual element type from the variance,
    // which should not be subject to quantization at a higher level.
    auto inputType = bnOp.operand().getType().dyn_cast<RankedTensorType>();
    auto varianceType = bnOp.variance().getType().dyn_cast<RankedTensorType>();
    if (!inputType || !varianceType) {
      return failure();
    }
    auto fpType = varianceType.getElementType().dyn_cast<FloatType>();
    if (!fpType) {
      return failure();
    }
    int64_t featureDim = bnOp.feature_index();

    // Add epsilon to the variance and sqrt to get stddev:
    // stddev = sqrt(variance + epsilon)
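    // The epsilon constant is materialized with the same shape as the
    // variance, so the add below is a plain elementwise op.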
    auto epsilon =
        materializeEpsilon(bnOp.getOperation(), bnOp.epsilonAttr(), fpType,
                           bnOp.variance(), varianceType, rewriter);
    if (!epsilon) {
      return failure();
    }
    Value stddev =
        rewriter.create<mhlo::AddOp>(bnOp.getLoc(), bnOp.variance(), epsilon);
    stddev = rewriter.create<mhlo::SqrtOp>(bnOp.getLoc(), stddev);

    // Broadcast all terms.
    Value shapeValue;
    if (!inputType.hasStaticShape()) {
      shapeValue = getShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter);
    }
    auto broadcastScale =
        broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.scale(),
                              shapeValue, featureDim, rewriter);
    auto broadcastOffset =
        broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.offset(),
                              shapeValue, featureDim, rewriter);
    auto broadcastMean =
        broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.mean(), shapeValue,
                              featureDim, rewriter);
    auto broadcastStddev = broadcastToFeatureDim(
        bnOp.getLoc(), inputType, stddev, shapeValue, featureDim, rewriter);

    // Compute:
    // scale * (input - mean) / stddev + offset
    Value result = rewriter.create<mhlo::SubtractOp>(
        bnOp.getLoc(), bnOp.operand(), broadcastMean);
    result =
        rewriter.create<mhlo::MulOp>(bnOp.getLoc(), result, broadcastScale);
    result =
        rewriter.create<mhlo::DivOp>(bnOp.getLoc(), result, broadcastStddev);
    rewriter.replaceOpWithNewOp<mhlo::AddOp>(bnOp, result, broadcastOffset);

    return success();
  }
};

// Creates an "mhlo.reduce" that sums 'operand' over the dimensions listed in
// 'reduceDims' (i.e. every dimension except 'featureIndex'), using 'zero' as
// the init value. The result is 1D with the size of the feature dimension.
Value createReduce(Location loc, Value operand, Value zero,
                   SmallVector<int64_t>& reduceDims, int64_t featureIndex,
                   PatternRewriter& rewriter) {
  auto operandType = operand.getType().cast<RankedTensorType>();
  Type reduceResultType = RankedTensorType::get(
      {operandType.getDimSize(featureIndex)}, operandType.getElementType());
  mhlo::ReduceOp reduce =
      rewriter.create<mhlo::ReduceOp>(loc, reduceResultType, operand, zero,
                                      rewriter.getI64TensorAttr(reduceDims));

  // Set up the "mhlo.reduce" body.
  Region& region = reduce.body();
  Block& block = region.emplaceBlock();
  RankedTensorType blockArgumentType =
      RankedTensorType::get({}, operandType.getElementType());
  block.addArgument(blockArgumentType, loc);
  block.addArgument(blockArgumentType, loc);
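  // The body takes two scalar block arguments and returns their sum, which
  // makes the reduction a summation.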
  auto* firstArgument = block.args_begin();
  auto secondArgument = block.args_rbegin();
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
    Value addResult =
        rewriter.create<mhlo::AddOp>(loc, *firstArgument, *secondArgument);
    rewriter.create<mhlo::ReturnOp>(loc, addResult);
  }

  return reduce.getResult(0);
}

// Calculates the number of elements that are reduced per feature, i.e.
// numElements(operand) / numElements(scale), and broadcasts it to the shape
// of 'scale'. Handles both the static-shape case and the dynamic-shape case
// with static rank.
Value calculateReduceSize(Operation *op, Value operand,
                          RankedTensorType operandType, Value scale,
                          RankedTensorType scaleType, int64_t featureIndex,
                          PatternRewriter &rewriter) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  Type indexType = b.getIndexType();
  if (!operandType.hasStaticShape()) {
    // The "operand" has a dynamic shape with static rank.
    Value operandShape = getShapeValue(op->getLoc(), operand, rewriter);
    Value scaleShape = getShapeValue(op->getLoc(), scale, rewriter);
    Value operandTotalSize =
        b.create<shape::NumElementsOp>(indexType, operandShape);
    Value scaleTotalSize =
        b.create<shape::NumElementsOp>(indexType, scaleShape);
    Value reduceSize =
        b.create<shape::DivOp>(indexType, operandTotalSize, scaleTotalSize);
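    // Turn the index-typed element count into a scalar tensor of the
    // operand's element type (index -> i64 -> tensor<1xi64> -> tensor<1xfp>
    // -> tensor<fp>), then broadcast it to the shape of 'scale'.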
    reduceSize = b.create<arith::IndexCastOp>(b.getI64Type(), reduceSize);
    reduceSize = b.create<tensor::FromElementsOp>(reduceSize);
    reduceSize = b.create<mhlo::ConvertOp>(
        RankedTensorType::get({1}, operandType.getElementType()), reduceSize);
    reduceSize = b.create<mhlo::ReshapeOp>(
        RankedTensorType::get({}, operandType.getElementType()), reduceSize);
    return b.createOrFold<mhlo::DynamicBroadcastInDimOp>(
        scaleType, reduceSize, scaleShape, b.getI64TensorAttr({}));
  }

  // The "operand" has a static shape.
  int64_t reduceDimsSize = 1;
  for (int64_t i = 0, e = operandType.getRank(); i < e; i++) {
    if (i != featureIndex) {
      reduceDimsSize *= operandType.getDimSize(i);
    }
  }
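  // With a static shape the per-feature element count is known at compile
  // time, so it can be folded into a constant splat of the scale type below.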
  llvm::APFloat floatValue(static_cast<double>(reduceDimsSize));
  bool losesInfo;
  floatValue.convert(
      scaleType.getElementType().cast<FloatType>().getFloatSemantics(),
      APFloat::rmNearestTiesToEven, &losesInfo);
  if (losesInfo) {
    op->emitWarning("Conversion of reduce_dims_size loses precision");
  }
  Value reduceSize = b.create<mhlo::ConstantOp>(
      DenseFPElementsAttr::get(scaleType, floatValue));
  return reduceSize;
}

// BatchNormTraining(X, scale, offset) =
//    ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale + offset.
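// E[X] and Var[X] are computed over every dimension except feature_index, and
// the rewrite below produces the op's three results: the normalized output,
// the batch mean, and the batch variance.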
class UnfuseBatchNormTrainingPattern
    : public OpRewritePattern<mhlo::BatchNormTrainingOp> {
 public:
  using OpRewritePattern<mhlo::BatchNormTrainingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::BatchNormTrainingOp bnOp,
                                PatternRewriter& rewriter) const override {
    auto operandType = bnOp.operand().getType().dyn_cast<RankedTensorType>();
    auto scaleType = bnOp.scale().getType().dyn_cast<RankedTensorType>();
    if (!operandType || !scaleType) {
      return failure();
    }
    auto fpType = operandType.getElementType().dyn_cast<FloatType>();
    if (!fpType) {
      return failure();
    }
    int64_t featureIndex = bnOp.feature_index();
    SmallVector<int64_t> dimensionsWithoutFeature;
    for (int64_t i = 0, e = operandType.getRank(); i < e; i++) {
      if (i != featureIndex) {
        dimensionsWithoutFeature.push_back(i);
      }
    }

    // zero constant
    Value constZero = rewriter.create<mhlo::ConstantOp>(
        bnOp.getLoc(),
        DenseFPElementsAttr::get(RankedTensorType::get({}, fpType),
                                 APFloat::getZero(fpType.getFloatSemantics())));
    // epsilon
    auto epsilon =
        materializeEpsilon(bnOp.getOperation(), bnOp.epsilonAttr(), fpType,
                           bnOp.scale(), scaleType, rewriter);
    if (!epsilon) {
      return failure();
    }
    // reduce size constant
    Value reduceSize =
        calculateReduceSize(bnOp.getOperation(), bnOp.operand(), operandType,
                            bnOp.scale(), scaleType, featureIndex, rewriter);
    if (!reduceSize) {
      return failure();
    }
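    // Compute the per-feature statistics in a single pass over the input:
    // E[X] = Sum[X] / N and Var[X] = E[X^2] - E[X]^2 (biased estimate),
    // where N is the reduce size computed above.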
    // Sum[X]
    Value sum = createReduce(bnOp.getLoc(), bnOp.operand(), constZero,
                             dimensionsWithoutFeature, featureIndex, rewriter);
    // X^2
    Value operandSquare = rewriter.create<mhlo::MulOp>(
        bnOp.getLoc(), bnOp.operand(), bnOp.operand());
    // Sum[X^2]
    Value squareSum =
        createReduce(bnOp.getLoc(), operandSquare, constZero,
                     dimensionsWithoutFeature, featureIndex, rewriter);
    // E[X]
    Value mean = rewriter.create<mhlo::DivOp>(bnOp.getLoc(), sum, reduceSize);
    // E[X^2]
    Value squareMean =
        rewriter.create<mhlo::DivOp>(bnOp.getLoc(), squareSum, reduceSize);
    // E^2[X]
    Value meanSquare = rewriter.create<mhlo::MulOp>(bnOp.getLoc(), mean, mean);
    // Var[X]
    Value var = rewriter.create<mhlo::SubtractOp>(bnOp.getLoc(), squareMean,
                                                  meanSquare);
    // Var[X] + epsilon
    Value varAddEpsilon =
        rewriter.create<mhlo::AddOp>(bnOp.getLoc(), var, epsilon);
    // Sqrt(Var[X] + epsilon)
    Value sqrtVar = rewriter.create<mhlo::SqrtOp>(bnOp.getLoc(), varAddEpsilon);

    Value shapeValue;
    if (!operandType.hasStaticShape()) {
      shapeValue = getShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter);
    }
    // X - E[X]
    Value meanBroadcast = broadcastToFeatureDim(
        bnOp.getLoc(), operandType, mean, shapeValue, featureIndex, rewriter);
    Value operandMinusMean = rewriter.create<mhlo::SubtractOp>(
        bnOp.getLoc(), bnOp.operand(), meanBroadcast);
    // (X - E[X]) / Sqrt(Var[X] + epsilon)
    Value sqrtVarBroadcast =
        broadcastToFeatureDim(bnOp.getLoc(), operandType, sqrtVar, shapeValue,
                              featureIndex, rewriter);
    Value normalized = rewriter.create<mhlo::DivOp>(
        bnOp.getLoc(), operandMinusMean, sqrtVarBroadcast);

    // ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale
    Value scaleBroadcast =
        broadcastToFeatureDim(bnOp.getLoc(), operandType, bnOp.scale(),
                              shapeValue, featureIndex, rewriter);
    Value scaledNormalized =
        rewriter.create<mhlo::MulOp>(bnOp.getLoc(), normalized, scaleBroadcast);
    // ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale + offset.
    Value offsetBroadcast =
        broadcastToFeatureDim(bnOp.getLoc(), operandType, bnOp.offset(),
                              shapeValue, featureIndex, rewriter);
    Value shiftedNormalized = rewriter.create<mhlo::AddOp>(
        bnOp.getLoc(), scaledNormalized, offsetBroadcast);

    // results
    SmallVector<Value> results = {shiftedNormalized, mean, var};
    rewriter.replaceOp(bnOp, results);

    return success();
  }
};

}  // namespace

// Populates conversion patterns to unfuse batch normalization operations.
// In combination with marking such ops as illegal, this allows backends that
// do not have special support for fused batchnorm to use simpler arithmetic
// primitives.
void populateUnfuseBatchNormInferencePattern(MLIRContext *context,
                                             RewritePatternSet *patterns) {
  patterns->add<UnfuseBatchNormInferencePattern>(context);
}

void populateUnfuseBatchNormTrainingPattern(MLIRContext *context,
                                            RewritePatternSet *patterns) {
  patterns->add<UnfuseBatchNormTrainingPattern>(context);
}
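
// A minimal usage sketch (the surrounding pass and the greedy driver call are
// assumptions, not part of this file): a pass would typically do
//   RewritePatternSet patterns(ctx);
//   mhlo::populateUnfuseBatchNormInferencePattern(ctx, &patterns);
//   mhlo::populateUnfuseBatchNormTrainingPattern(ctx, &patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));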

}  // namespace mhlo
}  // namespace mlir