/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Enable the use of M_* math constants.
// NOTE: this must be first in the file to ensure that if cmath is transitively
// included by any other header it has the define set on first processing.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
#define _USE_MATH_DEFINES
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace chlo {
namespace {

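// Lowers chlo::ConstantLikeOp to an mhlo::ConstantOp when the result type is
// statically shaped, and to a scalar constant broadcast to the operand's
// dynamic shape otherwise.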
struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
  using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ConstantLikeOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto resultTy = op.getType().cast<ShapedType>();

    // Unranked uses are not supported.
    if (!resultTy.hasRank()) return failure();

    // Lower to MHLO constant if statically shaped.
    if (resultTy.hasStaticShape()) {
      auto complexAttr = op.value().dyn_cast<complex::NumberAttr>();
      auto attr = complexAttr
                      ? DenseElementsAttr::get(resultTy, complexAttr.getValue())
                      : DenseElementsAttr::get(resultTy, op.value());
      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, attr);
      return success();
    }

    // Lower to broadcasted constant.
    auto loc = op.getLoc();
    Value constant = rewriter.create<mhlo::ConstantOp>(loc, op.value());
    Value shape = rewriter.create<shape::ShapeOfOp>(loc, adaptor.operand());
    rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
        op, resultTy, constant, shape, rewriter.getI64TensorAttr({}));
    return success();
  }
};

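// Evaluates a Chebyshev series via Clenshaw's recurrence; the caller is
// expected to have mapped the argument into the expansion's domain (as with
// Cephes' chbevl). For reference, a scalar sketch of the recurrence built
// below (illustrative only):
//   FTy b0 = 0, b1 = 0, b2 = 0;
//   for (FTy c : coefficients) { b2 = b1; b1 = b0; b0 = x * b1 - b2 + c; }
//   return 0.5 * (b0 - b2);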
template <typename FTy>
Value materializeChebyshevPolynomialApproximation(
    ConversionPatternRewriter &rewriter, Location loc, Value x,
    ArrayRef<FTy> coefficients) {
  Value b0 = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value b1 = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value b2 = chlo::getConstantLike(rewriter, loc, 0.0, x);
  for (FTy c : coefficients) {
    b2 = b1;
    b1 = b0;
    b0 = rewriter.create<mhlo::MulOp>(loc, x.getType(), x, b1);
    b0 = rewriter.create<mhlo::SubtractOp>(loc, x.getType(), b0, b2);
    b0 = rewriter.create<mhlo::AddOp>(
        loc, x.getType(), b0, chlo::getConstantLike(rewriter, loc, c, x));
  }
  Value result = rewriter.create<mhlo::SubtractOp>(loc, x.getType(), b0, b2);
  result = rewriter.create<mhlo::MulOp>(
      loc, x.getType(), result, chlo::getConstantLike(rewriter, loc, 0.5, x));
  return result;
}

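// Approximates the exponentially scaled modified Bessel function of the first
// kind, i1e(x) = exp(-|x|) * i1(x), using two Chebyshev expansions in the
// Cephes style: one on the transformed argument |x|/2 - 2 for |x| <= 8 and
// one on 32/|x| - 2 for |x| > 8, with the sign of x restored at the end.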
template <typename FTy>
Value materializeBesselI1eApproximation(ConversionPatternRewriter &rewriter,
                                        Location loc, Value x,
                                        ArrayRef<FTy> kI1eCoeffsA,
                                        ArrayRef<FTy> kI1eCoeffsB) {
  Value z = rewriter.create<mhlo::AbsOp>(loc, x);
  Value half = chlo::getConstantLike(rewriter, loc, 0.5, x);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value thirtyTwo = chlo::getConstantLike(rewriter, loc, 32.0, x);
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);

  Value tmp = rewriter.create<mhlo::MulOp>(loc, half, z);
  tmp = rewriter.create<mhlo::SubtractOp>(loc, tmp, two);

  Value xLe8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
                                                           kI1eCoeffsA);
  xLe8 = rewriter.create<mhlo::MulOp>(loc, z, xLe8);

  tmp = rewriter.create<mhlo::DivOp>(loc, thirtyTwo, z);
  tmp = rewriter.create<mhlo::SubtractOp>(loc, tmp, two);
  Value xGt8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
                                                           kI1eCoeffsB);
  xGt8 = rewriter.create<mhlo::DivOp>(loc, xGt8,
                                      rewriter.create<mhlo::SqrtOp>(loc, z));

  Value isLe8 = rewriter.create<mhlo::CompareOp>(loc, z, eight,
                                                 mhlo::ComparisonDirection::LE);

  Value select = rewriter.create<mhlo::SelectOp>(loc, isLe8, xLe8, xGt8);
  return rewriter.create<mhlo::MulOp>(
      loc, rewriter.create<mhlo::SignOp>(loc, x), select);
}

Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
                                           Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const float kI1eCoeffsA[] = {
      9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
      2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
      3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
      4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
      5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
      4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
      2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
      1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
      2.52587186443633654823E-1f};

  const float kI1eCoeffsB[] = {
      -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
      -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
      -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
      7.78576235018280120474E-1f};

  return materializeBesselI1eApproximation<float>(rewriter, loc, x, kI1eCoeffsA,
                                                  kI1eCoeffsB);
}

Value materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter,
                                           Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  const double kI1eCoeffsA[] = {
      2.77791411276104639959E-18, -2.11142121435816608115E-17,
      1.55363195773620046921E-16, -1.10559694773538630805E-15,
      7.60068429473540693410E-15, -5.04218550472791168711E-14,
      3.22379336594557470981E-13, -1.98397439776494371520E-12,
      1.17361862988909016308E-11, -6.66348972350202774223E-11,
      3.62559028155211703701E-10, -1.88724975172282928790E-9,
      9.38153738649577178388E-9, -4.44505912879632808065E-8,
      2.00329475355213526229E-7, -8.56872026469545474066E-7,
      3.47025130813767847674E-6, -1.32731636560394358279E-5,
      4.78156510755005422638E-5, -1.61760815825896745588E-4,
      5.12285956168575772895E-4, -1.51357245063125314899E-3,
      4.15642294431288815669E-3, -1.05640848946261981558E-2,
      2.47264490306265168283E-2, -5.29459812080949914269E-2,
      1.02643658689847095384E-1, -1.76416518357834055153E-1,
      2.52587186443633654823E-1};

  const double kI1eCoeffsB[] = {
      7.51729631084210481353E-18, 4.41434832307170791151E-18,
      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
      2.96262899764595013876E-16, 3.30820231092092828324E-16,
      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
      1.04202769841288027642E-14, 4.27244001671195135429E-14,
      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
      1.41258074366137813316E-11, 3.25260358301548823856E-11,
      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
      -3.83538038596423702205E-9, -2.63146884688951950684E-8,
      -2.51223623787020892529E-7, -3.88256480887769039346E-6,
      -1.10588938762623716291E-4, -9.76109749136146840777E-3,
      7.78576235018280120474E-1};

  return materializeBesselI1eApproximation<double>(rewriter, loc, x,
                                                   kI1eCoeffsA, kI1eCoeffsB);
}

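// Runs `callback` at a minimum floating-point precision: if the element type
// of the arguments is a float type narrower than `minPrecisionTy` (e.g. f16
// or bf16), the arguments are upcast first and the result is cast back to the
// original type afterwards.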
Value materializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
                            ValueRange args, FloatType minPrecisionTy,
                            Value callback(ConversionPatternRewriter &,
                                           Location, ValueRange)) {
  auto originalTy = getElementTypeOrSelf(args.front().getType());
  auto floatOriginalTy = originalTy.dyn_cast<FloatType>();
  bool needsUpcast =
      floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth();

  // Upcast arguments if necessary.
  llvm::SmallVector<Value, 2> castedArgs;
  if (needsUpcast) {
    for (Value a : args) {
      castedArgs.push_back(
          rewriter.create<mhlo::ConvertOp>(loc, a, minPrecisionTy));
    }
    args = castedArgs;
  }

  Value result = callback(rewriter, loc, args);

  // Cast back if necessary.
  if (needsUpcast) {
    result = rewriter.create<mhlo::ConvertOp>(loc, result, originalTy);
  }

  return result;
}

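// Lowers chlo::BesselI1eOp using the approximations above: f64 operands use
// the f64 coefficients directly; f32, f16, and bf16 operands are upcast to at
// least f32 and use the f32 coefficients.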
struct ConvertBesselI1eOp : public OpConversionPattern<BesselI1eOp> {
  using OpConversionPattern<BesselI1eOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      BesselI1eOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value x = adaptor.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, f16 and bf16.
    // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e
    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
      return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(
          op, materializeBesselI1eApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeBesselI1eApproximationF32));
    return success();
  }
};

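// Evaluates a polynomial with the given coefficients (ordered from the
// highest to the lowest degree) using Horner's rule:
//   p(x) = (...((c0 * x + c1) * x + c2) * x + ...) + cn.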
template <typename FTy>
Value materializePolynomialApproximation(ConversionPatternRewriter &rewriter,
                                         Location loc, Value x,
                                         ArrayRef<FTy> coefficients) {
  Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
  for (FTy c : coefficients) {
    poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
    poly = rewriter.create<mhlo::AddOp>(
        loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
  }
  return poly;
}

// Precondition: |x| >= 1. Otherwise, use the erf approximation.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value materializeErfcApproximationF64ForMagnitudeGeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  const double kMaxlog = 7.09782712893383996843E2;
  const double kErfcPCoefficients[] = {
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0, 4.86371970985681366614E1,
      1.96520832956077098242E2, 5.26445194995477358631E2,
      9.34528527171957607540E2, 1.02755188689515710272E3,
      5.57535335369399327526E2};
  const double kErfcQCoefficients[] = {
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  const double kErfcRCoefficients[] = {
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0, 6.16021097993053585195E0,
      7.40974269950448939160E0, 2.97886665372100240670E0};
  const double kErfcSCoefficients[] = {
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Let z = -x^2.
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, xSq);

  // Materialize polynomial approximation for x in [1, 8) as
  // erfc(x) = exp(z) P(|x|) / Q(|x|).
  Value expZ = rewriter.create<mhlo::ExpOp>(loc, z);
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value polP = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcPCoefficients));
  Value expZMulPolyP = rewriter.create<mhlo::MulOp>(loc, expZ, polP);
  Value polQ = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcQCoefficients));
  Value erfcApprox18 = rewriter.create<mhlo::DivOp>(loc, expZMulPolyP, polQ);

  // Materialize polynomial approximation for x >= 8 as
  // erfc(x) = exp(z) R(|x|) / S(|x|).
  Value polR = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcRCoefficients));
  Value expZMulPolyR = rewriter.create<mhlo::MulOp>(loc, expZ, polR);
  Value polS = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcSCoefficients));
  Value erfcApprox8Inf = rewriter.create<mhlo::DivOp>(loc, expZMulPolyR, polS);

  // Combine polynomial approximations for x >= 1.
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
  Value absXLt8 = rewriter.create<mhlo::CompareOp>(
      loc, absX, eight, mhlo::ComparisonDirection::LT);
  Value erfcApprox = rewriter.create<mhlo::SelectOp>(loc, absXLt8, erfcApprox18,
                                                     erfcApprox8Inf);

  // Clamp to prevent overflow and materialize approximation for large x as
  // erfc(x) = 0.
  Value zLtNegMaxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x),
      mhlo::ComparisonDirection::LT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfcApproxClamped =
      rewriter.create<mhlo::SelectOp>(loc, zLtNegMaxlog, zero, erfcApprox);

  // Derive approximation for x <= -1 as
  // erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations, all of which take |x| as
  // their argument.
  Value xLtZero = rewriter.create<mhlo::CompareOp>(
      loc, x, zero, mhlo::ComparisonDirection::LT);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value twoSubErfcApproxClamped =
      rewriter.create<mhlo::SubtractOp>(loc, two, erfcApproxClamped);
  return rewriter.create<mhlo::SelectOp>(loc, xLtZero, twoSubErfcApproxClamped,
                                         erfcApproxClamped);
}

// Precondition: |x| <= 1. Otherwise, use the erfc approximation.
// This implementation is based on Cephes.
Value materializeErfApproximationF64ForMagnitudeLeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  const double kErfTCoefficients[] = {
      9.60497373987051638749E0, 9.00260197203842689217E1,
      2.23200534594684319226E3, 7.00332514112805075473E3,
      5.55923013010394962768E4};
  const double kErfUCoefficients[] = {
      1.00000000000000000000E0, 3.35617141647503099647E1,
      5.21357949780152679795E2, 4.59432382970980127987E3,
      2.26290000613890934246E4, 4.92673942608635921086E4};

  // Materialize polynomial approximation for |x| <= 1 as
  // erf(x) = x T(x^2) / U(x^2).
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value polyT = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients));
  Value xMulPolyT = rewriter.create<mhlo::MulOp>(loc, x, polyT);
  Value polyU = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kErfUCoefficients));
  return rewriter.create<mhlo::DivOp>(loc, xMulPolyT, polyU);
}

// This implementation is based on Cephes.
Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  // Rely on erf approximation for |x| < 1
  // erf(x) = erf_approx(x)
  Value erfApprox =
      materializeErfApproximationF64ForMagnitudeLeOne(rewriter, loc, x);

  // Rely on erfc approximation for |x| >= 1 and materialize erf as
  // erf(x) = 1 - erfc_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erfcApprox =
      materializeErfcApproximationF64ForMagnitudeGeOne(rewriter, loc, x);
  Value erfcBasedApprox =
      rewriter.create<mhlo::SubtractOp>(loc, one, erfcApprox);

  // Materialize approximation selection based on argument.
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfApprox,
                                         erfcBasedApprox);
}

Value materializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
                                      Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  // Rely on erfc approximation for |x| >= 1
  // erfc(x) = erfc_approx(x)
  Value erfcApprox =
      materializeErfcApproximationF64ForMagnitudeGeOne(rewriter, loc, x);

  // Rely on erf approximation for |x| < 1 and materialize erfc as
  // erfc(x) = 1 - erf_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erfApprox =
      materializeErfApproximationF64ForMagnitudeLeOne(rewriter, loc, x);
  Value erfBasedApprox = rewriter.create<mhlo::SubtractOp>(loc, one, erfApprox);

  // Materialize approximation selection based on argument.
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfBasedApprox,
                                         erfcApprox);
}

// Precondition: |x| >= 1. Otherwise, use the erf approximation.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value materializeErfcApproximationF32ForMagnitudeGeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const double kMaxlog = 88.72283905206835;
  const float kErfcPCoefficients[] = {
      +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
      -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
      +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
  };
  const float kErfcRCoefficients[] = {
      -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
      +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
      -2.820767439740514E-1, +5.641895067754075E-1,
  };

  // Let z = -x^2.
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, xSq);

  // Materialize polynomial approximation for x >= 1 as
  // erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2)
  // erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value reciprocalXSq = rewriter.create<mhlo::DivOp>(loc, one, xSq);
  Value expZ = rewriter.create<mhlo::ExpOp>(loc, z);
  Value oneDivAbsX = rewriter.create<mhlo::DivOp>(loc, one, absX);
  Value expZMulOneDivAbsX = rewriter.create<mhlo::MulOp>(loc, expZ, oneDivAbsX);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value absXLtTwo = rewriter.create<mhlo::CompareOp>(
      loc, absX, two, mhlo::ComparisonDirection::LT);
  Value polP = materializePolynomialApproximation(
      rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcPCoefficients));
  Value polR = materializePolynomialApproximation(
      rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcRCoefficients));
  Value poly = rewriter.create<mhlo::SelectOp>(loc, absXLtTwo, polP, polR);
  Value erfcApprox = rewriter.create<mhlo::MulOp>(loc, expZMulOneDivAbsX, poly);

  // Clamp to prevent overflow and materialize approximation for large x as
  // erfc(x) = 0.
  Value zLtNegMaxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x),
      mhlo::ComparisonDirection::LT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfcApproxClamped =
      rewriter.create<mhlo::SelectOp>(loc, zLtNegMaxlog, zero, erfcApprox);

  // Derive approximation for x <= -1 as
  // erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations, all of which take |x| as
  // their argument.
  Value xLtZero = rewriter.create<mhlo::CompareOp>(
      loc, x, zero, mhlo::ComparisonDirection::LT);
  Value twoSubErfcApprox =
      rewriter.create<mhlo::SubtractOp>(loc, two, erfcApproxClamped);
  return rewriter.create<mhlo::SelectOp>(loc, xLtZero, twoSubErfcApprox,
                                         erfcApproxClamped);
}

// Precondition: |x| <= 1. Otherwise, use the erfc approximation.
// This implementation is based on Cephes.
Value materializeErfApproximationF32ForMagnitudeLeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const float kErfTCoefficients[] = {
      +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
      -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
      +1.128379165726710E+0,
  };

  // Materialize polynomial approximation for |x| <= 1 as
  // erf(x) = x T(x^2).
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value polyT = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients));
  return rewriter.create<mhlo::MulOp>(loc, x, polyT);
}

// This is the same approximation as used in Eigen.
Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const float kAlpha[] = {
      -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f,
  };
  const float kBeta[] = {
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f,
  };

  // Clamp argument between -4 and 4.
  Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
  Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
  x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);

  // Materialize polynomial approximation for x in [-4, 4] as
  // erf(x) = x * Alpha(x^2) / Beta(x^2).
  Value alphaPoly = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kAlpha));
  Value betaPoly = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kBeta));
  Value xMulAlphaPoly = rewriter.create<mhlo::MulOp>(loc, x, alphaPoly);
  return rewriter.create<mhlo::DivOp>(loc, xMulAlphaPoly, betaPoly);
}

Value materializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
                                      Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");

  // Rely on erfc approximation for |x| >= 1
  // erfc(x) = erfc_approx(x)
  Value erfcApprox =
      materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x);

  // Rely on erf approximation for |x| < 1 and materialize erfc as
  // erfc(x) = 1 - erf_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erfApprox =
      materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x);
  Value erfBasedApprox = rewriter.create<mhlo::SubtractOp>(loc, one, erfApprox);

  // Materialize approximation selection based on argument.
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfBasedApprox,
                                         erfcApprox);
}

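// Lowers chlo::ErfOp (and chlo::ErfcOp below) via the approximations above:
// f64 operands use the Cephes-based rational approximations directly, while
// f32, f16, and bf16 operands are upcast to at least f32 and use the f32
// approximations.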
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
  using OpConversionPattern<ErfOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ErfOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value x = adaptor.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, f16 and bf16.
    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
      return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeErfApproximationF32));
    return success();
  }
};

struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
  using OpConversionPattern<ErfcOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ErfcOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value x = adaptor.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, f16 and bf16.
    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
      return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeErfcApproximationF32));
    return success();
  }
};

// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7;  // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894, -176.61502916214059906584551354,
    12.507343278686904814458936853, -0.13857109526572011689554707,
    9.984369578019570859563e-6, 1.50563273514931155834e-7};

// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// lgamma(z + 1) = (log(2) + log(pi)) / 2
//                 + (z + 1/2) * log(t(z))
//                 - t(z) + log(a(z))
//   with t(z) = z + kLanczosGamma + 1/2
//        a(z) = kBaseLanczosCoeff
//               + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
                        ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value needToReflect = rewriter.create<mhlo::CompareOp>(
      loc, x, half, mhlo::ComparisonDirection::LT);
  Value negX = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value xSubOne = rewriter.create<mhlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mhlo::SelectOp>(loc, needToReflect, negX, xSubOne);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //          + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
    Value quotient = rewriter.create<mhlo::DivOp>(
        loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, oneBasedIndex));
    a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczosPlusHalf =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczosPlusHalf, z);
  Value logTerm =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1pTerm = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczosPlusHalf));
  Value logT = rewriter.create<mhlo::AddOp>(loc, logTerm, log1pTerm);

  // Note that t(z) may be large and we need to be careful not to overflow to
  // infinity in the relevant term
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // Therefore, we compute this as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value tDivLogT = rewriter.create<mhlo::DivOp>(loc, t, logT);
  Value sum = rewriter.create<mhlo::SubtractOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, z, half), tDivLogT);
  Value r = rewriter.create<mhlo::MulOp>(loc, sum, logT);

  // Compute the final result (modulo reflection) as
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value logA = rewriter.create<mhlo::LogOp>(loc, a);
  Value lgamma = rewriter.create<mhlo::AddOp>(
      loc,
      rewriter.create<mhlo::AddOp>(
          loc,
          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
          r),
      logA);

  // Compute the reflected value for x < 0.5 as
  //   lgamma(x) = log(pi) - lgamma(1 - x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // We have to be careful when computing the final term above. gamma(x) goes
  // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
  // term. The slope is large, so precision is particularly important.
  //
  // Because abs(sin(pi * x)) has a period of 1, we can equivalently use
  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
  // more numerically accurate: it doesn't overflow to inf like pi * x would,
  // and if x is an integer it evaluates to exactly 0, which is important
  // because we then take the log of this value, and log(0) is inf.
  //
  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
  // [0, 1] is symmetric across the line Y=0.5.

  // Convert values of absFrac > 0.5 to (1 - absFrac) to improve precision of
  // pi * absFrac for values of absFrac close to 1.
  Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absFrac = rewriter.create<mhlo::SubtractOp>(
      loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
  Value reduceAbsFrac = rewriter.create<mhlo::CompareOp>(
      loc, half, absFrac, mhlo::ComparisonDirection::LT);
  absFrac = rewriter.create<mhlo::SelectOp>(
      loc, reduceAbsFrac, rewriter.create<mhlo::SubtractOp>(loc, one, absFrac),
      absFrac);

  // Materialize reflection.
  Value reflectionDenom = rewriter.create<mhlo::LogOp>(
      loc,
      rewriter.create<mhlo::SineOp>(
          loc, rewriter.create<mhlo::MulOp>(
                   loc, getConstantLike(rewriter, loc, M_PI, x), absFrac)));
  Value lgammaReflection = rewriter.create<mhlo::SubtractOp>(
      loc,
      rewriter.create<mhlo::SubtractOp>(
          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
          reflectionDenom),
      lgamma);

  // Avoid computing -inf - inf, which is NaN. If reflectionDenom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value finiteReflectionDenom =
      rewriter.create<mhlo::IsFiniteOp>(loc, reflectionDenom);
  Value negReflectionDenom = rewriter.create<mhlo::NegOp>(loc, reflectionDenom);
  lgammaReflection = rewriter.create<mhlo::SelectOp>(
      loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mhlo::SelectOp>(loc, needToReflect, lgammaReflection,
                                           lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x);
  return rewriter.create<mhlo::SelectOp>(
      loc, xIsInf,
      chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
      lgamma);
}

// Express `cosh` as
//   cosh(x) = (e^x + e^-x) / 2
//           = e^(x + log(1/2)) + e^(-x + log(1/2))
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
                                   Location loc, ValueRange operands) {
  CoshOp::Adaptor transformed(operands);
  Value x = transformed.operand();

  Value logOneHalf =
      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
  Value expAdd = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, x, logOneHalf));
  Value expSub = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::SubtractOp>(loc, logOneHalf, x));
  return rewriter.create<mhlo::AddOp>(loc, expAdd, expSub);
}

struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
  using OpConversionPattern<CoshOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      CoshOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeCoshApproximation));
    return success();
  }
};

// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
//   with t(z) = z + kLanczosGamma + 1/2
//        a(z) = kBaseLanczosCoeff
//               + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//        a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
Value materializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
                         ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value needToReflect = rewriter.create<mhlo::CompareOp>(
      loc, x, half, mhlo::ComparisonDirection::LT);
  Value negX = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value xSubOne = rewriter.create<mhlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mhlo::SelectOp>(loc, needToReflect, negX, xSubOne);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //          + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  //   a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  Value aPrime = zero;
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
    Value zTerm = rewriter.create<mhlo::AddOp>(loc, z, oneBasedIndex);
    aPrime = rewriter.create<mhlo::SubtractOp>(
        loc, aPrime,
        rewriter.create<mhlo::DivOp>(
            loc, coeff, rewriter.create<mhlo::MulOp>(loc, zTerm, zTerm)));
    a = rewriter.create<mhlo::AddOp>(
        loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, zTerm));
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczosPlusHalf =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczosPlusHalf, z);
  Value logTerm =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1pTerm = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczosPlusHalf));
  Value logT = rewriter.create<mhlo::AddOp>(loc, logTerm, log1pTerm);

  // Materialize the final result (modulo reflection) as
  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value aPrimeDivA = rewriter.create<mhlo::DivOp>(loc, aPrime, a);
  Value lanczosGammaDivT = rewriter.create<mhlo::DivOp>(
      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
  Value digamma = rewriter.create<mhlo::SubtractOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, logT, aPrimeDivA),
      lanczosGammaDivT);

  // We need to be careful how we compute cot(pi * input) below: For
  // near-integral arguments, pi * input can lose precision.
  //
  // Input is already known to be less than 0.5 (otherwise we don't have to
  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
  // increase precision of pi * x and the resulting cotangent.
  Value reducedX = rewriter.create<mhlo::AddOp>(
      loc, x,
      rewriter.create<mhlo::AbsOp>(
          loc, rewriter.create<mhlo::FloorOp>(
                   loc, rewriter.create<mhlo::AddOp>(
                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));

  // Materialize reflection for inputs less than 0.5 as
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value piMulReducedX = rewriter.create<mhlo::MulOp>(loc, pi, reducedX);
  Value cos = rewriter.create<mhlo::CosineOp>(loc, piMulReducedX);
  Value sin = rewriter.create<mhlo::SineOp>(loc, piMulReducedX);
  Value reflection = rewriter.create<mhlo::SubtractOp>(
      loc, digamma,
      rewriter.create<mhlo::DivOp>(
          loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));

  // Select whether or not to rely on the reflection.
  digamma =
      rewriter.create<mhlo::SelectOp>(loc, needToReflect, reflection, digamma);

  // Digamma has poles at negative integers and zero; return NaN for those.
  Value isLeZero = rewriter.create<mhlo::CompareOp>(
      loc, x, zero, mhlo::ComparisonDirection::LE);
  Value isInt = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
      mhlo::ComparisonDirection::EQ);
  Value isPole = rewriter.create<mhlo::AndOp>(loc, isLeZero, isInt);
  return rewriter.create<mhlo::SelectOp>(
      loc, isPole,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      digamma);
}

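// Approximates the Hurwitz zeta function
//   zeta(x, q) = sum(k = 0, inf, 1 / (k + q)^x)
// by summing the first ten terms of the series directly and applying the
// Euler-Maclaurin formula for the tail, with special-case handling of the
// poles and domain errors below.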
Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                      ValueRange args) {
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12-term expansion for the Euler-Maclaurin formula.
  Value a = q;
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
  Value negPower = zero;
  Value negX = rewriter.create<mhlo::NegOp>(loc, x);
  Value initialSum = rewriter.create<mhlo::PowOp>(loc, q, negX);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mhlo::AddOp>(loc, a, one);
    negPower = rewriter.create<mhlo::PowOp>(loc, a, negX);
    initialSum = rewriter.create<mhlo::AddOp>(loc, initialSum, negPower);
  }
  a = rewriter.create<mhlo::AddOp>(loc, a, one);
  negPower = rewriter.create<mhlo::PowOp>(loc, a, negX);
  Value oneLikeX = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value xMinusOne = rewriter.create<mhlo::SubtractOp>(loc, x, oneLikeX);
  Value negPowerMulA = rewriter.create<mhlo::MulOp>(loc, negPower, a);
  Value negPowerMulADivXMinusOne =
      rewriter.create<mhlo::DivOp>(loc, negPowerMulA, xMinusOne);
  Value s =
      rewriter.create<mhlo::AddOp>(loc, initialSum, negPowerMulADivXMinusOne);
  Value aInverseSquare = rewriter.create<mhlo::DivOp>(
      loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));

  Value hornerSum = zero;
  Value factor = one;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule avoids some NaNs and infs, resulting in more
  // numerically stable code.
  for (int i = 0; i < 11; ++i) {
    Value factorLhs = rewriter.create<mhlo::SubtractOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factorRhs = rewriter.create<mhlo::SubtractOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
    factor = rewriter.create<mhlo::MulOp>(loc, factorLhs, factorRhs);
    hornerSum = rewriter.create<mhlo::MulOp>(
        loc, factor,
        rewriter.create<mhlo::MulOp>(
            loc, aInverseSquare,
            rewriter.create<mhlo::AddOp>(
                loc, hornerSum,
                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  Value zeroPointFiveLikeNegPower =
      chlo::getConstantLike(rewriter, loc, .5, negPower);
  Value xDivA = rewriter.create<mhlo::DivOp>(loc, x, a);
  s = rewriter.create<mhlo::AddOp>(
      loc, s,
      rewriter.create<mhlo::MulOp>(
          loc, negPower,
          rewriter.create<mhlo::AddOp>(
              loc, zeroPointFiveLikeNegPower,
              rewriter.create<mhlo::MulOp>(
                  loc, xDivA,
                  rewriter.create<mhlo::AddOp>(
                      loc,
                      chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
                                            a),
                      hornerSum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  Value absNegPower = rewriter.create<mhlo::AbsOp>(loc, negPower);
  Value absInitialSum = rewriter.create<mhlo::AbsOp>(loc, initialSum);
  Value output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(
          loc, absNegPower,
          rewriter.create<mhlo::MulOp>(
              loc, absInitialSum,
              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          mhlo::ComparisonDirection::LT),
      initialSum, s);

  // The function is not defined for x < 1.
  Value nan = chlo::getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(loc, x, oneLikeX,
                                       mhlo::ComparisonDirection::LT),
      nan, output);

  // For q <= 0, x must be an integer.
  Value qLeZero = rewriter.create<mhlo::CompareOp>(
      loc, q, zero, mhlo::ComparisonDirection::LE);
  Value xNotInt = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
      mhlo::ComparisonDirection::NE);
  Value xDomainError = rewriter.create<mhlo::AndOp>(loc, qLeZero, xNotInt);
  output = rewriter.create<mhlo::SelectOp>(loc, xDomainError, nan, output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  Value inf = chlo::getConstantLike(rewriter, loc,
                                    std::numeric_limits<double>::infinity(), x);
  Value qIsInt = rewriter.create<mhlo::CompareOp>(
      loc, q, rewriter.create<mhlo::FloorOp>(loc, q),
      mhlo::ComparisonDirection::EQ);
  Value atPole = rewriter.create<mhlo::AndOp>(loc, qLeZero, qIsInt);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value xIsInt = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
      mhlo::ComparisonDirection::EQ);
  Value xIsEven = rewriter.create<mhlo::CompareOp>(
      loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero,
      mhlo::ComparisonDirection::EQ);
  Value xIsEvenInt = rewriter.create<mhlo::AndOp>(loc, xIsInt, xIsEven);
  output = rewriter.create<mhlo::SelectOp>(
      loc, atPole, rewriter.create<mhlo::SelectOp>(loc, xIsEvenInt, inf, nan),
      output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(loc, x, one,
                                       mhlo::ComparisonDirection::EQ),
      inf, output);

  return output;
}

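// Materializes polygamma using the identity
//   polygamma(n, x) = (-1)^(n + 1) * n! * zeta(n + 1, x)
// where n! is computed as exp(lgamma(n + 1)) and the sign as 2 * rem(n, 2) - 1.
// n = 0 falls back to digamma; non-natural n (negative or non-integer) yields
// NaN.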
Value materializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange args) {
  PolygammaOp::Adaptor transformed(args);
  Value n = transformed.n();
  Value x = transformed.x();

  // Handle integer n > 0.
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value sign = rewriter.create<mhlo::SubtractOp>(
      loc,
      rewriter.create<mhlo::MulOp>(loc, two,
                                   rewriter.create<mhlo::RemOp>(loc, n, two)),
      one);
  Value nPlusOne = rewriter.create<mhlo::AddOp>(loc, n, one);
  Value expLgammaNp1 = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<chlo::LgammaOp>(loc, nPlusOne));
  Value zeta = rewriter.create<chlo::ZetaOp>(loc, nPlusOne, x);
  Value result = rewriter.create<mhlo::MulOp>(
      loc, rewriter.create<mhlo::MulOp>(loc, sign, expLgammaNp1), zeta);

  // Handle n = 0.
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value nEqZero = rewriter.create<mhlo::CompareOp>(
      loc, n, zero, mhlo::ComparisonDirection::EQ);
  result = rewriter.create<mhlo::SelectOp>(
      loc, nEqZero, rewriter.create<chlo::DigammaOp>(loc, x), result);

  // Check that n is a natural number. Return NaN otherwise.
  Value nonInt = rewriter.create<mhlo::CompareOp>(
      loc, n, rewriter.create<mhlo::FloorOp>(loc, n),
      mhlo::ComparisonDirection::NE);
  Value negative = rewriter.create<mhlo::CompareOp>(
      loc, n, zero, mhlo::ComparisonDirection::LT);
  Value nonNatural = rewriter.create<mhlo::OrOp>(loc, nonInt, negative);
  return rewriter.create<mhlo::SelectOp>(
      loc, nonNatural,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      result);
}

struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
  using OpConversionPattern<LgammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      LgammaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  minPrecisionTy, &materializeLgamma));
    return success();
  }
};

struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
  using OpConversionPattern<DigammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      DigammaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  minPrecisionTy, &materializeDigamma));
    return success();
  }
};

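// Materializes nextafter(x, y) (as in std::nextafter) by bitcasting the
// operands to same-width integers and stepping the integer representation of
// "x" by one unit in the last place toward "y", with explicit handling of NaN
// inputs, x == y, and signed zeros.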
Value materializeNextAfter(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange operands) {
  NextAfterOp::Adaptor transformed(operands);
  Value x = transformed.x();
  Value y = transformed.y();
  auto resultTy = x.getType().cast<ShapedType>();
  auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth();
  ImplicitLocOpBuilder b(loc, rewriter);
  auto intTy = resultTy.clone(b.getIntegerType(bitwidth));
  auto xAsInt = b.create<mhlo::BitcastConvertOp>(intTy, x);
  auto yAsInt = b.create<mhlo::BitcastConvertOp>(intTy, y);

  // The result is NaN if either "x" or "y" is NaN.
  auto xIsNan = b.create<mhlo::CompareOp>(x, x, mhlo::ComparisonDirection::NE);
  auto yIsNan = b.create<mhlo::CompareOp>(y, y, mhlo::ComparisonDirection::NE);
  auto nanInput = b.create<mhlo::OrOp>(xIsNan, yIsNan);
  auto resultForNan = getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  auto resultForNanAsInt =
      b.create<mhlo::BitcastConvertOp>(intTy, resultForNan);

  // The sign bit is the MSB.
  const int64_t signBit = int64_t{1} << (bitwidth - 1);
  // Discard the sign bit to make the result non-negative.
  auto signMask = getConstantLike(rewriter, loc, signBit, xAsInt);
  auto negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt);
  auto xAbs = b.create<mhlo::AndOp>(xAsInt, negatedSignMask);
  auto yAbs = b.create<mhlo::AndOp>(yAsInt, negatedSignMask);

  // When "x" and "y" are equal, the result is "y".
  auto xAndYAreEqual =
      b.create<mhlo::CompareOp>(x, y, mhlo::ComparisonDirection::EQ);
  auto resultForEqual = yAsInt;

  // When both "x" and "y" are 0, the result is "y". This is a separate case
  // from above because "x" and "y" might have different signs.
  auto zero = getConstantLike(rewriter, loc, 0, xAsInt);
  auto xIsZero =
      b.create<mhlo::CompareOp>(xAbs, zero, mhlo::ComparisonDirection::EQ);
  auto yIsZero =
      b.create<mhlo::CompareOp>(yAbs, zero, mhlo::ComparisonDirection::EQ);
  auto resultForBothZero = yAsInt;

  auto xSign = b.create<mhlo::AndOp>(xAsInt, signMask);
  auto ySign = b.create<mhlo::AndOp>(yAsInt, signMask);

  // If x == 0 && y != 0, we need to return the smallest subnormal number
  // signed like "y".
  auto one = getConstantLike(rewriter, loc, 1, xAsInt);
  auto resultForXZeroYNonZero = b.create<mhlo::OrOp>(ySign, one);

  // If the signs of "x" and "y" disagree:
  // - we need to make the magnitude of "x" smaller so that it is closer to
  //   zero.
  //
  // Otherwise the signs agree:
  // - "x" with a magnitude larger than "y" means we need to make the magnitude
  //   smaller.
  // - "x" with a magnitude smaller than "y" means we need to make the
  //   magnitude larger.
  auto signsDisagree =
      b.create<mhlo::CompareOp>(xSign, ySign, mhlo::ComparisonDirection::NE);
  auto xMagnitudeLargerThanY =
      b.create<mhlo::CompareOp>(xAbs, yAbs, mhlo::ComparisonDirection::GT);
  auto resultHasSmallerMagnitude =
      b.create<mhlo::OrOp>(xMagnitudeLargerThanY, signsDisagree);
  auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt);
  auto magnitudeAdjustment =
      b.create<mhlo::SelectOp>(resultHasSmallerMagnitude, minusOne, one);
  Value result = b.create<mhlo::AddOp>(xAsInt, magnitudeAdjustment);
  // Handle x == +-0.
  result = b.create<mhlo::SelectOp>(
      xIsZero,
      b.create<mhlo::SelectOp>(yIsZero, resultForBothZero,
                               resultForXZeroYNonZero),
      result);
  // Handle x == y.
  result = b.create<mhlo::SelectOp>(xAndYAreEqual, resultForEqual, result);
  // Handle isnan(x) || isnan(y).
  result = b.create<mhlo::SelectOp>(nanInput, resultForNanAsInt, result);

  // Cast back to the original type.
  return b.create<mhlo::BitcastConvertOp>(resultTy, result);
}
1237
1238 struct ConvertNextAfterOp : public OpConversionPattern<NextAfterOp> {
1239 using OpConversionPattern<NextAfterOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      NextAfterOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(
        op, materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands()));
    return success();
  }
};

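// Converts chlo.polygamma by materializing the polygamma approximation,
// upcasting operands narrower than f32 to f32 first, since the approximation
// assumes at least single precision.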
struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
  using OpConversionPattern<PolygammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      PolygammaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  minPrecisionTy, &materializePolygamma));
    return success();
  }
};

// Sinh(x) = (e^x - e^-x) / 2
//         = e^(x + log(1/2)) - e^(-x + log(1/2)).
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
// inf.
//
// This incorrectly overflows to +/-inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
Value materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
                                            Location loc, ValueRange operands) {
  SinhOp::Adaptor transformed(operands);
  Value x = transformed.operand();

  Value logOneHalf =
      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
  Value expAdd = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, x, logOneHalf));
  Value expSub = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::SubtractOp>(loc, logOneHalf, x));
  return rewriter.create<mhlo::SubtractOp>(loc, expAdd, expSub);
}

// Express `sinh` as
//   sinh(x) = (e^x - e^-x) / 2                      if |x| < 1
//           = e^(x + log(1/2)) - e^(-x + log(1/2))  otherwise.
Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
                                   Location loc, ValueRange operands) {
  Value largeSinhResult =
      materializeSinhApproximationForLargeX(rewriter, loc, operands);

  SinhOp::Adaptor transformed(operands);
  Value x = transformed.operand();

  // For smaller x, the naive formula suffers from catastrophic cancellation
  // between e^x and e^-x, and the result collapses to 0. Rewrite it to avoid
  // that. We use expm1(x) because it preserves the first-order term of the
  // Taylor series of e^x:
  //   (e^x - e^-x) / 2
  //   = (e^x - 1 + 1 - e^-x) / 2
  //   = (expm1(x) + (e^x - 1) / e^x) / 2
  //   = (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
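  // For example, for x = 1e-8 in f32, both e^x and e^-x round to 1.0f (f32
  // epsilon is ~1.19e-7), so the naive difference is exactly 0, whereas
  // expm1(1e-8) ~= 1e-8 keeps the linear term and the rewritten formula
  // returns ~1e-8, matching sinh(1e-8).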
  Value expm1 = rewriter.create<mhlo::Expm1Op>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value oneHalf = getConstantLike(rewriter, loc, 0.5, x);
  Value expm1PlusOne = rewriter.create<mhlo::AddOp>(loc, expm1, one);
  Value ratio = rewriter.create<mhlo::DivOp>(loc, expm1, expm1PlusOne);
  Value sum = rewriter.create<mhlo::AddOp>(loc, expm1, ratio);
  Value smallSinhResult = rewriter.create<mhlo::MulOp>(loc, oneHalf, sum);

  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, smallSinhResult,
                                         largeSinhResult);
}

struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
  using OpConversionPattern<SinhOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      SinhOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value x = adaptor.operand();
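    // Complex operands skip the small-|x| rewrite (it selects on the
    // real-valued comparison |x| < 1) and use the large-x formulation
    // directly.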
    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
      rewriter.replaceOp(op, materializeSinhApproximationForLargeX(
                                 rewriter, op.getLoc(), adaptor.getOperands()));
      return success();
    }
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeSinhApproximation));
    return success();
  }
};

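// Express `tan` as
//   tan(x) = sin(x) / cos(x).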
Value materializeTan(ConversionPatternRewriter &rewriter, Location loc,
                     ValueRange operands) {
  TanOp::Adaptor transformed(operands);
  return rewriter.create<mhlo::DivOp>(
      loc, rewriter.create<mhlo::SineOp>(loc, transformed.operand()),
      rewriter.create<mhlo::CosineOp>(loc, transformed.operand()));
}

struct ConvertTanOp : public OpConversionPattern<TanOp> {
  using OpConversionPattern<TanOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      TanOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  rewriter.getF32Type(), &materializeTan));
    return success();
  }
};

// Converts chlo.top_k to MHLO iota, sort, and slice ops.
//
// chlo.top_k sorts along the last dimension of the input tensor and then
// returns the values and indices of the top K components. This is translated
// into a few ops in MHLO: first generating an integer sequence for the
// indices, then sorting both the original input tensor and the indices
// together, and finally slicing out the top K components.
//
// For example, for the following IR:
//
// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> ->
//                                   (tensor<16x8xf32>, tensor<16x8xi32>)
//
// We will get:
//
// %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "mhlo.sort"(%input, %1) ({
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
//      %arg3: tensor<i32>, %arg4: tensor<i32>):
//   %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
//   "mhlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
//                        start_indices = dense<0> : tensor<2xi64>,
//                        strides = dense<1> : tensor<2xi64>} :
//                       (tensor<16x16xf32>) -> tensor<16x8xf32>
// %6 = "mhlo.slice"(%4) ...
struct ConvertTopKOp : public OpConversionPattern<TopKOp> {
  using OpConversionPattern<TopKOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      TopKOp op, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter &rewriter) const override {
    // The last dimension of the operand's shape should be known so we can have
    // clamped end_indices for slices. This is verified by the op.
    auto operandType = op.operand().getType().cast<RankedTensorType>();
    int64_t operandRank = operandType.getRank();
    int64_t lastDimIndex = operandRank - 1;
    int64_t lastDimSize = operandType.getDimSize(lastDimIndex);
    assert(lastDimSize != ShapedType::kDynamicSize);

    // Create an Iota op for indices.
    auto i32Type = rewriter.getIntegerType(32);
    Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type);
    Value iotaOp = rewriter.create<mhlo::IotaOp>(
        op.getLoc(), iotaType, rewriter.getI64IntegerAttr(lastDimIndex));

    // Create the sort op. It takes two inputs, one for the original input, the
    // other for the indices. Use TOTALORDER comparison type instead of the
    // default comparison if the element type is of type float.
    Type elementType = operandType.getElementType();
    auto sortOp = createSortOp(&rewriter, op.getLoc(), {op.operand(), iotaOp},
                               {elementType, i32Type}, lastDimIndex,
                               /*isStable=*/true,
                               /*direction=*/mhlo::ComparisonDirection::GT);
    // Get the sorted input and index tensors from the sort results.
    auto tupleFirstElement = sortOp.getResult(0);
    auto tupleSecondElement = sortOp.getResult(1);

    SmallVector<int64_t, 4> beginIndices(operandRank, 0);
    auto endIndices = llvm::to_vector<4>(operandType.getShape());
    endIndices.back() = std::min(static_cast<int64_t>(op.k()), lastDimSize);
    SmallVector<int64_t, 4> strides(operandRank, 1);

    // Get the slice for the top K elements.
    auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type());
    Value values = rewriter.create<mhlo::SliceOp>(
        op.getLoc(), tupleFirstElement,
        DenseIntElementsAttr::get(indicesTy, beginIndices),
        DenseIntElementsAttr::get(indicesTy, endIndices),
        DenseIntElementsAttr::get(indicesTy, strides));
    Value indices = rewriter.create<mhlo::SliceOp>(
        op.getLoc(), tupleSecondElement,
        DenseIntElementsAttr::get(indicesTy, beginIndices),
        DenseIntElementsAttr::get(indicesTy, endIndices),
        DenseIntElementsAttr::get(indicesTy, strides));

    rewriter.replaceOp(op, {values, indices});
    return success();
  }
};

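// Converts chlo.zeta by materializing the zeta approximation, upcasting
// operands narrower than f32 to f32 first, since the approximation assumes at
// least single precision.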
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
  using OpConversionPattern<ZetaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ZetaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  minPrecisionTy, &materializeZeta));
    return success();
  }
};

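// Converts chlo.broadcast_select to mhlo.select by explicitly broadcasting
// the predicate and both branches to the result shape inside a shape.assuming
// region that carries the broadcastability constraint.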
struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
  using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      BroadcastSelectOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value pred = adaptor.pred();
    Value onTrue = adaptor.on_true();
    Value onFalse = adaptor.on_false();
    auto predType = pred.getType().dyn_cast<RankedTensorType>();
    auto onTrueType = onTrue.getType().dyn_cast<RankedTensorType>();
    auto onFalseType = onFalse.getType().dyn_cast<RankedTensorType>();
    auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!predType || !onTrueType || !onFalseType || !resultType) {
      return failure();
    }

    auto loc = op.getLoc();

    Value predShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
    Value onTrueShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onTrue);
    Value onFalseShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onFalse);
    int64_t resultRank = std::max(
        {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()});

    Value broadcastableCstr = rewriter.createOrFold<shape::CstrBroadcastableOp>(
        loc, ValueRange{predShape, onTrueShape, onFalseShape});
    auto assumingOp = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{resultType}, broadcastableCstr);

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assumingOp.getDoRegion());

    Value resultExtents = rewriter.createOrFold<shape::BroadcastOp>(
        loc, shape::getExtentTensorType(op.getContext()),
        ValueRange{predShape, onTrueShape, onFalseShape},
        /*error=*/nullptr);
    auto shapeType =
        RankedTensorType::get({resultRank}, rewriter.getIndexType());
    resultExtents =
        rewriter.createOrFold<tensor::CastOp>(loc, shapeType, resultExtents);

    Value broadcastedPred = pred;
    // Pred has an implicit broadcast for scalars, so use that when convenient.
    if (predType.getRank() > 0) {
      auto predBroadcastDimensions = llvm::to_vector<4>(
          llvm::seq<int64_t>(resultRank - predType.getRank(), resultRank));
      broadcastedPred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
          loc,
          RankedTensorType::get(resultType.getShape(),
                                predType.getElementType()),
          pred, resultExtents,
          rewriter.getI64TensorAttr(predBroadcastDimensions));
    }
    auto onTrueBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - onTrueType.getRank(), resultRank));
    Value broadcastedOnTrue = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(),
                              onTrueType.getElementType()),
        onTrue, resultExtents,
        rewriter.getI64TensorAttr(onTrueBroadcastDimensions));
    auto onFalseBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - onFalseType.getRank(), resultRank));
    Value broadcastedOnFalse = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(),
                              onFalseType.getElementType()),
        onFalse, resultExtents,
        rewriter.getI64TensorAttr(onFalseBroadcastDimensions));

    // And generate the final non-broadcasted ternary op.
    Value finalResult =
        rewriter.create<mhlo::SelectOp>(loc, resultType, broadcastedPred,
                                        broadcastedOnTrue, broadcastedOnFalse);
    rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
    rewriter.replaceOp(op, {assumingOp.getResult(0)});
    return success();
  }
};

// Converts binary ops that are statically determined not to broadcast directly
// to the corresponding mhlo non-broadcasting op.
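// For example (a sketch; exact syntax may vary), a chlo.broadcast_add of two
// tensor<4x4xf32> operands becomes a plain mhlo.add:
//   %0 = mhlo.add %lhs, %rhs : tensor<4x4xf32>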
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    auto lhsType =
        adaptor.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhsType =
        adaptor.rhs().getType().template dyn_cast<RankedTensorType>();
    if (!lhsType || !rhsType) return failure();

    // Differing ranks would require a rank broadcast.
    if (lhsType.getRank() != rhsType.getRank()) return failure();
    // Any dynamic dimension may require broadcasting and requires more
    // analysis.
    if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape())
      return failure();

    for (auto extents : llvm::zip(lhsType.getShape(), rhsType.getShape())) {
      auto lhsExtent = std::get<0>(extents);
      auto rhsExtent = std::get<1>(extents);
      if (lhsExtent != rhsExtent) {
        return failure();
      }
    }

    rewriter.replaceOp(op,
                       {Adaptor::createOp(op, op.getResult().getType(),
                                          adaptor.getOperands(), rewriter)});
    return success();
  }
};

// Converts a binary op with ranked broadcasting operands by explicitly
// broadcasting the operands and then invoking the corresponding mhlo
// non-broadcasting op. Note that the dynamic broadcasting supported by this
// pattern is only valid for "numpy" broadcasting semantics as defined here:
//   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
//   - Same-rank broadcast (operands have the same static rank).
//   - Different-rank broadcast, either without a broadcast_dims attribute or
//     with the broadcast_dims attribute set to map to a prefix padding.
//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
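//
// As a rough sketch (exact op syntax and attributes may differ), a
// chlo.broadcast_add of tensor<?xf32> and tensor<?x?xf32> becomes:
//
//   %cstr = shape.cstr_broadcastable %lhs_shape, %rhs_shape
//   %res = shape.assuming %cstr -> tensor<?x?xf32> {
//     %blhs = "mhlo.dynamic_broadcast_in_dim"(%lhs, %extents)
//         {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : ...
//     %brhs = "mhlo.dynamic_broadcast_in_dim"(%rhs, %extents)
//         {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : ...
//     %add = mhlo.add %blhs, %brhs : tensor<?x?xf32>
//     shape.assuming_yield %add : tensor<?x?xf32>
//   }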
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value lhs = adaptor.lhs();
    Value rhs = adaptor.rhs();
    auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
    auto resultType =
        op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!lhsType || !rhsType || !resultType) return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcastDimensions = op.broadcast_dimensions();
    if (broadcastDimensions &&
        !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, *broadcastDimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcastDimensions;
      return failure();
    }

    // Compute result shape.
    auto loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastableCstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape);
    auto assumingOp = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{resultType}, broadcastableCstr.getResult());

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assumingOp.getDoRegion());

    int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank());
    Value resultExtents =
        hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                               rewriter);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit them, and some of them require analysis to
    // prove properly.
    auto lhsBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - lhsType.getRank(), resultRank));
    Value broadcastedLhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(), lhsType.getElementType()),
        lhs, resultExtents, rewriter.getI64TensorAttr(lhsBroadcastDimensions));
    auto rhsBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank));
    Value broadcastedRhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(), rhsType.getElementType()),
        rhs, resultExtents, rewriter.getI64TensorAttr(rhsBroadcastDimensions));

    // And generate the final non-broadcasted binary op.
    Value finalResult = Adaptor::createOp(
        op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter);
    rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
    rewriter.replaceOp(op, {assumingOp.getResult(0)});
    return success();
  }
};

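// Converts chlo.dynamic_reshape to an mhlo.dynamic_reshape whose target shape
// is recomputed by mhlo.compute_reshape_shape, guarded by an
// mhlo.cstr_reshapable constraint inside a shape.assuming region.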
class ConvertDynamicReshapeOp
    : public OpRewritePattern<chlo::DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto tensor = op.operand();
    auto shape = op.output_shape();

    auto shapeTy = shape.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
    Value numEls = rewriter.create<shape::NumElementsOp>(loc, inputShape);
    Value cstr = rewriter.create<mhlo::CstrReshapableOp>(loc, numEls, shape);
    rewriter.replaceOpWithNewOp<shape::AssumingOp>(
        op, cstr, [&](OpBuilder &b, Location l) {
          Value computedShape =
              b.create<mhlo::ComputeReshapeShapeOp>(l, shapeTy, numEls, shape);
          SmallVector<Value> result;
          result.push_back(b.create<mhlo::DynamicReshapeOp>(l, resultTy, tensor,
                                                            computedShape));
          return result;
        });

    return success();
  }
};

#include "generated_chlo_legalize_to_hlo.inc"
}  // namespace

void populateChloBroadcastingPatterns(MLIRContext *context,
                                      RewritePatternSet *patterns) {
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
  populateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);
  patterns
      ->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
          context);
}

void populateDecomposeChloPatterns(MLIRContext *context,
                                   RewritePatternSet *patterns) {
  populateWithGenerated(*patterns);

  // Other patterns.
  // clang-format off
  patterns->add<ConvertBesselI1eOp,
                ConvertCoshOp,
                ConvertDigammaOp,
                ConvertErfOp,
                ConvertErfcOp,
                ConvertLgammaOp,
                ConvertNextAfterOp,
                ConvertPolygammaOp,
                ConvertSinhOp,
                ConvertTanOp,
                ConvertTopKOp,
                ConvertZetaOp>(context);
  // clang-format on
}

}  // namespace chlo
}  // namespace mlir