/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Enable the use of M_* math constants.
// NOTE: this must be first in the file to ensure that if cmath is transitively
// included by any other header it has the define set on first processing.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
#define _USE_MATH_DEFINES
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace chlo {
namespace {

struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
  using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ConstantLikeOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto resultTy = op.getType().cast<ShapedType>();

    // Unranked uses are not supported.
    if (!resultTy.hasRank()) return failure();

    // Lower to MHLO constant if statically shaped.
    if (resultTy.hasStaticShape()) {
      auto complexAttr = op.value().dyn_cast<complex::NumberAttr>();
      auto attr = complexAttr
                      ? DenseElementsAttr::get(resultTy, complexAttr.getValue())
                      : DenseElementsAttr::get(resultTy, op.value());
      rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, attr);
      return success();
    }

    // Lower to broadcasted constant.
    auto loc = op.getLoc();
    Value constant = rewriter.create<mhlo::ConstantOp>(loc, op.value());
    Value shape = rewriter.create<shape::ShapeOfOp>(loc, adaptor.operand());
    rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
        op, resultTy, constant, shape, rewriter.getI64TensorAttr({}));
    return success();
  }
};
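
// For illustration (a sketch only; IR syntax abbreviated, not verbatim): a
// statically shaped chlo.constant_like collapses into a single mhlo.constant,
// while a ranked but dynamically shaped one becomes an mhlo.constant of the
// scalar value, a shape.shape_of on the operand, and an
// mhlo.dynamic_broadcast_in_dim that broadcasts the scalar to that shape.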

template <typename FTy>
Value materializeChebyshevPolynomialApproximation(
    ConversionPatternRewriter &rewriter, Location loc, Value x,
    ArrayRef<FTy> coefficients) {
  Value b0 = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value b1 = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value b2 = chlo::getConstantLike(rewriter, loc, 0.0, x);
  for (FTy c : coefficients) {
    b2 = b1;
    b1 = b0;
    b0 = rewriter.create<mhlo::MulOp>(loc, x.getType(), x, b1);
    b0 = rewriter.create<mhlo::SubtractOp>(loc, x.getType(), b0, b2);
    b0 = rewriter.create<mhlo::AddOp>(
        loc, x.getType(), b0, chlo::getConstantLike(rewriter, loc, c, x));
  }
  Value result = rewriter.create<mhlo::SubtractOp>(loc, x.getType(), b0, b2);
  result = rewriter.create<mhlo::MulOp>(
      loc, x.getType(), result, chlo::getConstantLike(rewriter, loc, 0.5, x));
  return result;
}
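
// For reference, a minimal scalar sketch (illustrative only, not part of the
// lowering) of the Clenshaw-style recurrence that the builder above emits
// elementwise; coefficients are expected highest-order first:
//
//   template <typename FTy>
//   FTy chebyshevEval(FTy x, ArrayRef<FTy> coefficients) {
//     FTy b0 = 0, b1 = 0, b2 = 0;
//     for (FTy c : coefficients) {
//       b2 = b1;
//       b1 = b0;
//       b0 = x * b1 - b2 + c;
//     }
//     return FTy(0.5) * (b0 - b2);
//   }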

template <typename FTy>
Value materializeBesselI1eApproximation(ConversionPatternRewriter &rewriter,
                                        Location loc, Value x,
                                        ArrayRef<FTy> kI1eCoeffsA,
                                        ArrayRef<FTy> kI1eCoeffsB) {
  Value z = rewriter.create<mhlo::AbsOp>(loc, x);
  Value half = chlo::getConstantLike(rewriter, loc, 0.5, x);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value thirtyTwo = chlo::getConstantLike(rewriter, loc, 32.0, x);
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);

  Value tmp = rewriter.create<mhlo::MulOp>(loc, half, z);
  tmp = rewriter.create<mhlo::SubtractOp>(loc, tmp, two);

  Value xLe8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
                                                           kI1eCoeffsA);
  xLe8 = rewriter.create<mhlo::MulOp>(loc, z, xLe8);

  tmp = rewriter.create<mhlo::DivOp>(loc, thirtyTwo, z);
  tmp = rewriter.create<mhlo::SubtractOp>(loc, tmp, two);
  Value xGt8 = materializeChebyshevPolynomialApproximation(rewriter, loc, tmp,
                                                           kI1eCoeffsB);
  xGt8 = rewriter.create<mhlo::DivOp>(loc, xGt8,
                                      rewriter.create<mhlo::SqrtOp>(loc, z));

  Value isLe8 = rewriter.create<mhlo::CompareOp>(loc, z, eight,
                                                 mhlo::ComparisonDirection::LE);

  Value select = rewriter.create<mhlo::SelectOp>(loc, isLe8, xLe8, xGt8);
  return rewriter.create<mhlo::MulOp>(
      loc, rewriter.create<mhlo::SignOp>(loc, x), select);
}
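
// Scalar sketch of the two-regime evaluation above (illustrative pseudocode
// only; A and B stand for the Chebyshev evaluations of kI1eCoeffsA and
// kI1eCoeffsB):
//
//   i1e(x):
//     z = |x|
//     le8 = z * A(z / 2 - 2)             // used when z <= 8
//     gt8 = B(32 / z - 2) / sqrt(z)      // used when z > 8
//     return sign(x) * (z <= 8 ? le8 : gt8)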

Value materializeBesselI1eApproximationF32(ConversionPatternRewriter &rewriter,
                                           Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const float kI1eCoeffsA[] = {
      9.38153738649577178388E-9f, -4.44505912879632808065E-8f,
      2.00329475355213526229E-7f, -8.56872026469545474066E-7f,
      3.47025130813767847674E-6f, -1.32731636560394358279E-5f,
      4.78156510755005422638E-5f, -1.61760815825896745588E-4f,
      5.12285956168575772895E-4f, -1.51357245063125314899E-3f,
      4.15642294431288815669E-3f, -1.05640848946261981558E-2f,
      2.47264490306265168283E-2f, -5.29459812080949914269E-2f,
      1.02643658689847095384E-1f, -1.76416518357834055153E-1f,
      2.52587186443633654823E-1f};

  const float kI1eCoeffsB[] = {
      -3.83538038596423702205E-9f, -2.63146884688951950684E-8f,
      -2.51223623787020892529E-7f, -3.88256480887769039346E-6f,
      -1.10588938762623716291E-4f, -9.76109749136146840777E-3f,
      7.78576235018280120474E-1f};

  return materializeBesselI1eApproximation<float>(rewriter, loc, x, kI1eCoeffsA,
                                                  kI1eCoeffsB);
}

Value materializeBesselI1eApproximationF64(ConversionPatternRewriter &rewriter,
                                           Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  const double kI1eCoeffsA[] = {
      2.77791411276104639959E-18, -2.11142121435816608115E-17,
      1.55363195773620046921E-16, -1.10559694773538630805E-15,
      7.60068429473540693410E-15, -5.04218550472791168711E-14,
      3.22379336594557470981E-13, -1.98397439776494371520E-12,
      1.17361862988909016308E-11, -6.66348972350202774223E-11,
      3.62559028155211703701E-10, -1.88724975172282928790E-9,
      9.38153738649577178388E-9,  -4.44505912879632808065E-8,
      2.00329475355213526229E-7,  -8.56872026469545474066E-7,
      3.47025130813767847674E-6,  -1.32731636560394358279E-5,
      4.78156510755005422638E-5,  -1.61760815825896745588E-4,
      5.12285956168575772895E-4,  -1.51357245063125314899E-3,
      4.15642294431288815669E-3,  -1.05640848946261981558E-2,
      2.47264490306265168283E-2,  -5.29459812080949914269E-2,
      1.02643658689847095384E-1,  -1.76416518357834055153E-1,
      2.52587186443633654823E-1};

  const double kI1eCoeffsB[] = {
      7.51729631084210481353E-18,  4.41434832307170791151E-18,
      -4.65030536848935832153E-17, -3.20952592199342395980E-17,
      2.96262899764595013876E-16,  3.30820231092092828324E-16,
      -1.88035477551078244854E-15, -3.81440307243700780478E-15,
      1.04202769841288027642E-14,  4.27244001671195135429E-14,
      -2.10154184277266431302E-14, -4.08355111109219731823E-13,
      -7.19855177624590851209E-13, 2.03562854414708950722E-12,
      1.41258074366137813316E-11,  3.25260358301548823856E-11,
      -1.89749581235054123450E-11, -5.58974346219658380687E-10,
      -3.83538038596423702205E-9,  -2.63146884688951950684E-8,
      -2.51223623787020892529E-7,  -3.88256480887769039346E-6,
      -1.10588938762623716291E-4,  -9.76109749136146840777E-3,
      7.78576235018280120474E-1};

  return materializeBesselI1eApproximation<double>(rewriter, loc, x,
                                                   kI1eCoeffsA, kI1eCoeffsB);
}

Value materializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
                            ValueRange args, FloatType minPrecisionTy,
                            Value callback(ConversionPatternRewriter &,
                                           Location, ValueRange)) {
  auto originalTy = getElementTypeOrSelf(args.front().getType());
  auto floatOriginalTy = originalTy.dyn_cast<FloatType>();
  bool needsUpcast =
      floatOriginalTy && floatOriginalTy.getWidth() < minPrecisionTy.getWidth();

  // Upcast arguments if necessary.
  llvm::SmallVector<Value, 2> castedArgs;
  if (needsUpcast) {
    for (Value a : args) {
      castedArgs.push_back(
          rewriter.create<mhlo::ConvertOp>(loc, a, minPrecisionTy));
    }
    args = castedArgs;
  }

  Value result = callback(rewriter, loc, args);

  // Cast back if necessary.
  if (needsUpcast) {
    result = rewriter.create<mhlo::ConvertOp>(loc, result, originalTy);
  }

  return result;
}
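
// For example (illustrative only; someF32Callback is a placeholder for any
// materializer matching the callback signature above), a pattern that needs
// at least f32 precision can write:
//
//   rewriter.replaceOp(
//       op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
//                                 rewriter.getF32Type(), &someF32Callback));
//
// f16/bf16 operands are then converted to f32 before the callback runs, and
// the result is converted back to the original element type afterwards.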

struct ConvertBesselI1eOp : public OpConversionPattern<BesselI1eOp> {
  using OpConversionPattern<BesselI1eOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      BesselI1eOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value x = adaptor.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, f16 and bf16.
    // See https://www.tensorflow.org/api_docs/python/tf/math/bessel_i1e
    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
      return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(
          op, materializeBesselI1eApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeBesselI1eApproximationF32));
    return success();
  }
};

template <typename FTy>
Value materializePolynomialApproximation(ConversionPatternRewriter &rewriter,
                                         Location loc, Value x,
                                         ArrayRef<FTy> coefficients) {
  Value poly = chlo::getConstantLike(rewriter, loc, 0.0, x);
  for (FTy c : coefficients) {
    poly = rewriter.create<mhlo::MulOp>(loc, x.getType(), poly, x);
    poly = rewriter.create<mhlo::AddOp>(
        loc, x.getType(), poly, chlo::getConstantLike(rewriter, loc, c, x));
  }
  return poly;
}
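
// The loop above is Horner's rule evaluated elementwise on the device. As a
// scalar sketch (illustrative only; coefficients ordered from highest to
// lowest degree):
//
//   template <typename FTy>
//   FTy polynomialEval(FTy x, ArrayRef<FTy> coefficients) {
//     FTy poly = 0;
//     for (FTy c : coefficients) poly = poly * x + c;
//     return poly;
//   }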

// Precondition is |x| >= 1. Use the erf approximation otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value materializeErfcApproximationF64ForMagnitudeGeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  const double kMaxlog = 7.09782712893383996843E2;
  const double kErfcPCoefficients[] = {
      2.46196981473530512524E-10, 5.64189564831068821977E-1,
      7.46321056442269912687E0,   4.86371970985681366614E1,
      1.96520832956077098242E2,   5.26445194995477358631E2,
      9.34528527171957607540E2,   1.02755188689515710272E3,
      5.57535335369399327526E2};
  const double kErfcQCoefficients[] = {
      1.00000000000000000000E0, 1.32281951154744992508E1,
      8.67072140885989742329E1, 3.54937778887819891062E2,
      9.75708501743205489753E2, 1.82390916687909736289E3,
      2.24633760818710981792E3, 1.65666309194161350182E3,
      5.57535340817727675546E2};
  const double kErfcRCoefficients[] = {
      5.64189583547755073984E-1, 1.27536670759978104416E0,
      5.01905042251180477414E0,  6.16021097993053585195E0,
      7.40974269950448939160E0,  2.97886665372100240670E0};
  const double kErfcSCoefficients[] = {
      1.00000000000000000000E0, 2.26052863220117276590E0,
      9.39603524938001434673E0, 1.20489539808096656605E1,
      1.70814450747565897222E1, 9.60896809063285878198E0,
      3.36907645100081516050E0};

  // Let z = -x^2.
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, xSq);

  // Materialize polynomial approximation for x in [1, 8) as
  //   erfc(x) = exp(z) P(|x|) / Q(|x|).
  Value expZ = rewriter.create<mhlo::ExpOp>(loc, z);
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value polP = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcPCoefficients));
  Value expZMulPolyP = rewriter.create<mhlo::MulOp>(loc, expZ, polP);
  Value polQ = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcQCoefficients));
  Value erfcApprox18 = rewriter.create<mhlo::DivOp>(loc, expZMulPolyP, polQ);

  // Materialize polynomial approximation for x >= 8 as
  //   erfc(x) = exp(z) R(|x|) / S(|x|).
  Value polR = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcRCoefficients));
  Value expZMulPolyR = rewriter.create<mhlo::MulOp>(loc, expZ, polR);
  Value polS = materializePolynomialApproximation(
      rewriter, loc, absX, llvm::makeArrayRef(kErfcSCoefficients));
  Value erfcApprox8Inf = rewriter.create<mhlo::DivOp>(loc, expZMulPolyR, polS);

  // Combine polynomial approximations for x >= 1.
  Value eight = chlo::getConstantLike(rewriter, loc, 8.0, x);
  Value absXLt8 = rewriter.create<mhlo::CompareOp>(
      loc, absX, eight, mhlo::ComparisonDirection::LT);
  Value erfcApprox = rewriter.create<mhlo::SelectOp>(loc, absXLt8, erfcApprox18,
                                                     erfcApprox8Inf);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value zLtNegMaxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x),
      mhlo::ComparisonDirection::LT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfcApproxClamped =
      rewriter.create<mhlo::SelectOp>(loc, zLtNegMaxlog, zero, erfcApprox);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations, all of which take |x| as
  // their argument.
  Value xLtZero = rewriter.create<mhlo::CompareOp>(
      loc, x, zero, mhlo::ComparisonDirection::LT);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value twoSubErfcApproxClamped =
      rewriter.create<mhlo::SubtractOp>(loc, two, erfcApproxClamped);
  return rewriter.create<mhlo::SelectOp>(loc, xLtZero, twoSubErfcApproxClamped,
                                         erfcApproxClamped);
}
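
// In scalar form, the piecewise structure materialized above is roughly
// (illustrative pseudocode; p, q, r, s denote the four polynomial
// evaluations at |x|):
//
//   erfcGeOne(x):
//     z = -x * x
//     approx = |x| < 8 ? exp(z) * p / q : exp(z) * r / s
//     if (z < -kMaxlog) approx = 0          // clamp for very large |x|
//     return x < 0 ? 2 - approx : approx    // erfc(x) = 2 - erfc(-x)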

// Precondition is |x| <= 1. Use the erfc approximation otherwise.
// This implementation is based on Cephes.
Value materializeErfApproximationF64ForMagnitudeLeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");
  const double kErfTCoefficients[] = {
      9.60497373987051638749E0, 9.00260197203842689217E1,
      2.23200534594684319226E3, 7.00332514112805075473E3,
      5.55923013010394962768E4};
  const double kErfUCoefficients[] = {
      1.00000000000000000000E0, 3.35617141647503099647E1,
      5.21357949780152679795E2, 4.59432382970980127987E3,
      2.26290000613890934246E4, 4.92673942608635921086E4};

  // Materialize polynomial approximation for |x| <= 1 as
  //   erf(x) = x T(x^2) / U(x^2).
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value polyT = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients));
  Value xMulPolyT = rewriter.create<mhlo::MulOp>(loc, x, polyT);
  Value polyU = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kErfUCoefficients));
  return rewriter.create<mhlo::DivOp>(loc, xMulPolyT, polyU);
}

// This implementation is based on Cephes.
Value materializeErfApproximationF64(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  // Rely on the erf approximation for |x| < 1
  //   erf(x) = erf_approx(x)
  Value erfApprox =
      materializeErfApproximationF64ForMagnitudeLeOne(rewriter, loc, x);

  // Rely on the erfc approximation for |x| >= 1 and materialize erf as
  //   erf(x) = 1 - erfc_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erfcApprox =
      materializeErfcApproximationF64ForMagnitudeGeOne(rewriter, loc, x);
  Value erfcBasedApprox =
      rewriter.create<mhlo::SubtractOp>(loc, one, erfcApprox);

  // Materialize approximation selection based on the argument.
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfApprox,
                                         erfcBasedApprox);
}

Value materializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
                                      Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
         "expect f64 element type");

  // Rely on the erfc approximation for |x| >= 1
  //   erfc(x) = erfc_approx(x)
  Value erfcApprox =
      materializeErfcApproximationF64ForMagnitudeGeOne(rewriter, loc, x);

  // Rely on the erf approximation for |x| < 1 and materialize erfc as
  //   erfc(x) = 1 - erf_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erfApprox =
      materializeErfApproximationF64ForMagnitudeLeOne(rewriter, loc, x);
  Value erfBasedApprox = rewriter.create<mhlo::SubtractOp>(loc, one, erfApprox);

  // Materialize approximation selection based on the argument.
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfBasedApprox,
                                         erfcApprox);
}

// Precondition is |x| >= 1. Use the erf approximation otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value materializeErfcApproximationF32ForMagnitudeGeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const double kMaxlog = 88.72283905206835;
  const float kErfcPCoefficients[] = {
      +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
      -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
      +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
  };
  const float kErfcRCoefficients[] = {
      -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
      +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
      -2.820767439740514E-1, +5.641895067754075E-1,
  };

  // Let z = -x^2.
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value z = rewriter.create<mhlo::NegOp>(loc, xSq);

  // Materialize polynomial approximation for x >= 1 as
  //   erfc(x) = exp(z) 1/x P(1/x^2)   if x in [1, 2)
  //   erfc(x) = exp(z) 1/x R(1/x^2)   if x >= 2
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value reciprocalXSq = rewriter.create<mhlo::DivOp>(loc, one, xSq);
  Value expZ = rewriter.create<mhlo::ExpOp>(loc, z);
  Value oneDivAbsX = rewriter.create<mhlo::DivOp>(loc, one, absX);
  Value expZMulOneDivAbsX = rewriter.create<mhlo::MulOp>(loc, expZ, oneDivAbsX);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value absXLtTwo = rewriter.create<mhlo::CompareOp>(
      loc, absX, two, mhlo::ComparisonDirection::LT);
  Value polP = materializePolynomialApproximation(
      rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcPCoefficients));
  Value polR = materializePolynomialApproximation(
      rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcRCoefficients));
  Value poly = rewriter.create<mhlo::SelectOp>(loc, absXLtTwo, polP, polR);
  Value erfcApprox = rewriter.create<mhlo::MulOp>(loc, expZMulOneDivAbsX, poly);

  // Clamp to prevent overflow and materialize approximation for large x as
  //   erfc(x) = 0.
  Value zLtNegMaxlog = rewriter.create<mhlo::CompareOp>(
      loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x),
      mhlo::ComparisonDirection::LT);
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
  Value erfcApproxClamped =
      rewriter.create<mhlo::SelectOp>(loc, zLtNegMaxlog, zero, erfcApprox);

  // Derive approximation for x <= -1 as
  //   erfc(x) = 2 - erfc(-x).
  // Reuse previously materialized approximations, all of which take |x| as
  // their argument.
  Value xLtZero = rewriter.create<mhlo::CompareOp>(
      loc, x, zero, mhlo::ComparisonDirection::LT);
  Value twoSubErfcApprox =
      rewriter.create<mhlo::SubtractOp>(loc, two, erfcApproxClamped);
  return rewriter.create<mhlo::SelectOp>(loc, xLtZero, twoSubErfcApprox,
                                         erfcApproxClamped);
}

// Precondition is |x| <= 1. Use the erfc approximation otherwise.
// This implementation is based on Cephes.
Value materializeErfApproximationF32ForMagnitudeLeOne(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const float kErfTCoefficients[] = {
      +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
      -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
      +1.128379165726710E+0,
  };

  // Materialize polynomial approximation for |x| <= 1 as
  //   erf(x) = x T(x^2).
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);
  Value polyT = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients));
  return rewriter.create<mhlo::MulOp>(loc, x, polyT);
}

// This is the same approximation as used in Eigen.
Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter,
                                     Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");
  const float kAlpha[] = {
      -2.72614225801306e-10f, 2.77068142495902e-08f,  -2.10102402082508e-06f,
      -5.69250639462346e-05f, -7.34990630326855e-04f, -2.95459980854025e-03f,
      -1.60960333262415e-02f,
  };
  const float kBeta[] = {
      -1.45660718464996e-05f, -2.13374055278905e-04f, -1.68282697438203e-03f,
      -7.37332916720468e-03f, -1.42647390514189e-02f,
  };

  // Clamp the argument to [-4, 4].
  Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
  Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
  x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
  Value xSq = rewriter.create<mhlo::MulOp>(loc, x, x);

  // Materialize polynomial approximation for x in [-4, 4] as
  //   erf(x) = x * Alpha(x^2) / Beta(x^2).
  Value alphaPoly = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kAlpha));
  Value betaPoly = materializePolynomialApproximation(
      rewriter, loc, xSq, llvm::makeArrayRef(kBeta));
  Value xMulAlphaPoly = rewriter.create<mhlo::MulOp>(loc, x, alphaPoly);
  return rewriter.create<mhlo::DivOp>(loc, xMulAlphaPoly, betaPoly);
}
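
// Scalar sketch of the rational approximation above (illustrative pseudocode;
// alpha and beta denote the Horner evaluations of kAlpha and kBeta):
//
//   erfEigen(x):
//     x = clamp(x, -4, 4)
//     return x * alpha(x * x) / beta(x * x)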

Value materializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
                                      Location loc, ValueRange args) {
  Value x = args.front();
  assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
         "expect f32 element type");

  // Rely on the erfc approximation for |x| >= 1
  //   erfc(x) = erfc_approx(x)
  Value erfcApprox =
      materializeErfcApproximationF32ForMagnitudeGeOne(rewriter, loc, x);

  // Rely on the erf approximation for |x| < 1 and materialize erfc as
  //   erfc(x) = 1 - erf_approx(x)
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value erfApprox =
      materializeErfApproximationF32ForMagnitudeLeOne(rewriter, loc, x);
  Value erfBasedApprox = rewriter.create<mhlo::SubtractOp>(loc, one, erfApprox);

  // Materialize approximation selection based on the argument.
  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, erfBasedApprox,
                                         erfcApprox);
}

struct ConvertErfOp : public OpConversionPattern<ErfOp> {
  using OpConversionPattern<ErfOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ErfOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value x = adaptor.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, f16 and bf16.
    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
      return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(op, materializeErfApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeErfApproximationF32));
    return success();
  }
};

struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
  using OpConversionPattern<ErfcOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ErfcOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value x = adaptor.operand();
    Type ty = x.getType().cast<ShapedType>().getElementType();

    // For now, we support only f64, f32, f16 and bf16.
    if (!ty.isF64() && !ty.isF32() && !ty.isF16() && !ty.isBF16())
      return failure();

    if (ty.isF64()) {
      rewriter.replaceOp(op, materializeErfcApproximationF64(rewriter, loc, x));
      return success();
    }

    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeErfcApproximationF32));
    return success();
  }
};

// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
// [7, 9] seemed to be the least sensitive to the quality of the log function.
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
// for a particularly inaccurate log function.
constexpr double kLanczosGamma = 7;  // aka g
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
constexpr std::array<double, 8> kLanczosCoefficients = {
    676.520368121885098567009190444019, -1259.13921672240287047156078755283,
    771.3234287776530788486528258894,   -176.61502916214059906584551354,
    12.507343278686904814458936853,     -0.13857109526572011689554707,
    9.984369578019570859563e-6,         1.50563273514931155834e-7};

// Compute the Lgamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   lgamma(z + 1) = (log(2) + log(pi)) / 2
//                     + (z + 1/2) * log(t(z))
//                     - t(z) + log(a(z))
//   with   t(z) = z + kLanczosGamma + 1/2
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
Value materializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
                        ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value needToReflect = rewriter.create<mhlo::CompareOp>(
      loc, x, half, mhlo::ComparisonDirection::LT);
  Value negX = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value xSubOne = rewriter.create<mhlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mhlo::SelectOp>(loc, needToReflect, negX, xSubOne);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //            + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
    Value quotient = rewriter.create<mhlo::DivOp>(
        loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, oneBasedIndex));
    a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczosPlusHalf =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczosPlusHalf, z);
  Value logTerm =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1pTerm = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczosPlusHalf));
  Value logT = rewriter.create<mhlo::AddOp>(loc, logTerm, log1pTerm);

  // Note that t(z) may be large and we need to be careful not to overflow to
  // infinity in the relevant term
  //   r = (z + 1/2) * log(t(z)) - t(z).
  // Therefore, we compute this as
  //   r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
  Value tDivLogT = rewriter.create<mhlo::DivOp>(loc, t, logT);
  Value sum = rewriter.create<mhlo::SubtractOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, z, half), tDivLogT);
  Value r = rewriter.create<mhlo::MulOp>(loc, sum, logT);

  // Compute the final result (modulo reflection) as
  //   lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
  Value logA = rewriter.create<mhlo::LogOp>(loc, a);
  Value lgamma = rewriter.create<mhlo::AddOp>(
      loc,
      rewriter.create<mhlo::AddOp>(
          loc,
          getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
          r),
      logA);

  // Compute the reflected value for x < 0.5 as
  //   lgamma(x) = log(pi) - lgamma(1 - x) - log(abs(sin(pi * x))).
  //
  // The abs is needed because lgamma is the log of the absolute value of the
  // gamma function.
  //
  // We have to be careful when computing the final term above. gamma(x) goes
  // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
  // term. The slope is large, so precision is particularly important.
  //
  // Because abs(sin(pi * x)) has a period of 1, we can equivalently use
  // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
  // more numerically accurate: it doesn't overflow to inf like pi * x would,
  // and if x is an integer it evaluates to exactly 0, which is important
  // because we then take the log of this value, and log(0) is inf.
  //
  // We don't have a frac(x) primitive in HLO and computing it is tricky, but
  // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
  // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
  //
  // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
  // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
  // [0, 1] is symmetric across the line Y=0.5.

  // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
  // pi * abs_frac for values of abs_frac close to 1.
  Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absFrac = rewriter.create<mhlo::SubtractOp>(
      loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
  Value reduceAbsFrac = rewriter.create<mhlo::CompareOp>(
      loc, half, absFrac, mhlo::ComparisonDirection::LT);
  absFrac = rewriter.create<mhlo::SelectOp>(
      loc, reduceAbsFrac, rewriter.create<mhlo::SubtractOp>(loc, one, absFrac),
      absFrac);

  // Materialize reflection.
  Value reflectionDenom = rewriter.create<mhlo::LogOp>(
      loc,
      rewriter.create<mhlo::SineOp>(
          loc, rewriter.create<mhlo::MulOp>(
                   loc, getConstantLike(rewriter, loc, M_PI, x), absFrac)));
  Value lgammaReflection = rewriter.create<mhlo::SubtractOp>(
      loc,
      rewriter.create<mhlo::SubtractOp>(
          loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
          reflectionDenom),
      lgamma);

  // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
  // then it "wins" and the result is +/-inf.
  Value finiteReflectionDenom =
      rewriter.create<mhlo::IsFiniteOp>(loc, reflectionDenom);
  Value negReflectionDenom = rewriter.create<mhlo::NegOp>(loc, reflectionDenom);
  lgammaReflection = rewriter.create<mhlo::SelectOp>(
      loc, finiteReflectionDenom, lgammaReflection, negReflectionDenom);

  // Select whether or not to rely on the reflection.
  lgamma = rewriter.create<mhlo::SelectOp>(loc, needToReflect, lgammaReflection,
                                           lgamma);

  // Materialize +/-inf behavior as
  //   lgamma(+/-inf) = +inf.
  Value xIsInf = rewriter.create<chlo::IsInfOp>(loc, x);
  return rewriter.create<mhlo::SelectOp>(
      loc, xIsInf,
      chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
      lgamma);
}
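
// Scalar sketch of the Lanczos core above (reflection and infinity handling
// omitted; illustrative only):
//
//   double lgammaLanczos(double z) {  // computes lgamma(z + 1)
//     double a = kBaseLanczosCoeff;
//     for (int k = 0; k < 8; ++k) a += kLanczosCoefficients[k] / (z + k + 1);
//     double t = z + kLanczosGamma + 0.5;
//     double logT = std::log(kLanczosGamma + 0.5) +
//                   std::log1p(z / (kLanczosGamma + 0.5));
//     return (std::log(2) + std::log(M_PI)) / 2 +
//            (z + 0.5 - t / logT) * logT + std::log(a);
//   }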

// Express `cosh` as
//   cosh(x) = (e^x + e^-x) / 2
//           = e^(x + log(1/2)) + e^(-x + log(1/2))
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2).  The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
Value materializeCoshApproximation(ConversionPatternRewriter &rewriter,
                                   Location loc, ValueRange operands) {
  CoshOp::Adaptor transformed(operands);
  Value x = transformed.operand();

  Value logOneHalf =
      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
  Value expAdd = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, x, logOneHalf));
  Value expSub = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::SubtractOp>(loc, logOneHalf, x));
  return rewriter.create<mhlo::AddOp>(loc, expAdd, expSub);
}
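
// Scalar sketch of the overflow-safe formulation (illustrative only):
//
//   float coshSafe(float x) {
//     float logHalf = std::log(0.5f);
//     return std::exp(x + logHalf) + std::exp(logHalf - x);
//   }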

struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
  using OpConversionPattern<CoshOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      CoshOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeCoshApproximation));
    return success();
  }
};

// Compute the Digamma function using Lanczos' approximation from "A Precision
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
// series B. Vol. 1:
//   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
//   with   t(z) = z + kLanczosGamma + 1/2
//          a(z) = kBaseLanczosCoeff
//                   + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
//          a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1]
//                                    / (z + k) / (z + k))
Value materializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
                         ValueRange args) {
  // If the input is less than 0.5 use Euler's reflection formula.
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  // Let z be
  //   z = -x      if x < 1/2
  //   z = x - 1   otherwise
  Value x = args.front();
  Value half = getConstantLike(rewriter, loc, 0.5, x);
  Value needToReflect = rewriter.create<mhlo::CompareOp>(
      loc, x, half, mhlo::ComparisonDirection::LT);
  Value negX = rewriter.create<mhlo::NegOp>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1, x);
  Value xSubOne = rewriter.create<mhlo::SubtractOp>(loc, x, one);
  Value z = rewriter.create<mhlo::SelectOp>(loc, needToReflect, negX, xSubOne);

  // Materialize
  //   a(z) = kBaseLanczosCoeff
  //            + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k))
  //   a'(z) = - sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k) / (z + k))
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
  Value aPrime = zero;
  for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
    Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
    Value oneBasedIndex = getConstantLike(rewriter, loc, i + 1, x);
    Value zTerm = rewriter.create<mhlo::AddOp>(loc, z, oneBasedIndex);
    aPrime = rewriter.create<mhlo::SubtractOp>(
        loc, aPrime,
        rewriter.create<mhlo::DivOp>(
            loc, coeff, rewriter.create<mhlo::MulOp>(loc, zTerm, zTerm)));
    a = rewriter.create<mhlo::AddOp>(
        loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, zTerm));
  }

  // To improve accuracy on platforms with less-precise log implementations,
  // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
  // device.
  // Materialize as
  //   log(t) = log(kLanczosGamma + 1/2 + z)
  //          = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
  Value lanczosPlusHalf =
      getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
  Value t = rewriter.create<mhlo::AddOp>(loc, lanczosPlusHalf, z);
  Value logTerm =
      getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
  Value log1pTerm = rewriter.create<mhlo::Log1pOp>(
      loc, rewriter.create<mhlo::DivOp>(loc, z, lanczosPlusHalf));
  Value logT = rewriter.create<mhlo::AddOp>(loc, logTerm, log1pTerm);

  // Materialize the final result (modulo reflection) as
  //   digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
  Value aPrimeDivA = rewriter.create<mhlo::DivOp>(loc, aPrime, a);
  Value lanczosGammaDivT = rewriter.create<mhlo::DivOp>(
      loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
  Value digamma = rewriter.create<mhlo::SubtractOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, logT, aPrimeDivA),
      lanczosGammaDivT);

  // We need to be careful how we compute cot(pi * input) below: For
  // near-integral arguments, pi * input can lose precision.
  //
  // Input is already known to be less than 0.5 (otherwise we don't have to
  // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
  // increase precision of pi * x and the resulting cotangent.
  Value reducedX = rewriter.create<mhlo::AddOp>(
      loc, x,
      rewriter.create<mhlo::AbsOp>(
          loc, rewriter.create<mhlo::FloorOp>(
                   loc, rewriter.create<mhlo::AddOp>(
                            loc, x, getConstantLike(rewriter, loc, 0.5, x)))));

  // Materialize reflection for inputs less than 0.5 as
  //   digamma(x) = digamma(1 - x) - pi * cot(pi * x)
  //              = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
  Value pi = getConstantLike(rewriter, loc, M_PI, x);
  Value piMulReducedX = rewriter.create<mhlo::MulOp>(loc, pi, reducedX);
  Value cos = rewriter.create<mhlo::CosineOp>(loc, piMulReducedX);
  Value sin = rewriter.create<mhlo::SineOp>(loc, piMulReducedX);
  Value reflection = rewriter.create<mhlo::SubtractOp>(
      loc, digamma,
      rewriter.create<mhlo::DivOp>(
          loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));

  // Select whether or not to rely on the reflection.
  digamma =
      rewriter.create<mhlo::SelectOp>(loc, needToReflect, reflection, digamma);

  // Digamma has poles at negative integers and zero; return NaN for those.
  Value isLeZero = rewriter.create<mhlo::CompareOp>(
      loc, x, zero, mhlo::ComparisonDirection::LE);
  Value isInt = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
      mhlo::ComparisonDirection::EQ);
  Value isPole = rewriter.create<mhlo::AndOp>(loc, isLeZero, isInt);
  return rewriter.create<mhlo::SelectOp>(
      loc, isPole,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      digamma);
}
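
// Scalar sketch of the Lanczos core for digamma (reflection and pole handling
// omitted; illustrative only):
//
//   double digammaLanczos(double z) {  // computes digamma(z + 1)
//     double a = kBaseLanczosCoeff, aPrime = 0.0;
//     for (int k = 0; k < 8; ++k) {
//       double zTerm = z + k + 1;
//       aPrime -= kLanczosCoefficients[k] / (zTerm * zTerm);
//       a += kLanczosCoefficients[k] / zTerm;
//     }
//     double t = z + kLanczosGamma + 0.5;
//     double logT = std::log(kLanczosGamma + 0.5) +
//                   std::log1p(z / (kLanczosGamma + 0.5));
//     return logT + aPrime / a - kLanczosGamma / t;
//   }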

Value materializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                      ValueRange args) {
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12 term expansion for the Euler-Maclaurin formula.
  Value a = q;
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
  Value negPower = zero;
  Value negX = rewriter.create<mhlo::NegOp>(loc, x);
  Value initialSum = rewriter.create<mhlo::PowOp>(loc, q, negX);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mhlo::AddOp>(loc, a, one);
    negPower = rewriter.create<mhlo::PowOp>(loc, a, negX);
    initialSum = rewriter.create<mhlo::AddOp>(loc, initialSum, negPower);
  }
  a = rewriter.create<mhlo::AddOp>(loc, a, one);
  negPower = rewriter.create<mhlo::PowOp>(loc, a, negX);
  Value oneLikeX = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value xMinusOne = rewriter.create<mhlo::SubtractOp>(loc, x, oneLikeX);
  Value negPowerMulA = rewriter.create<mhlo::MulOp>(loc, negPower, a);
  Value negPowerMulADivXMinusOne =
      rewriter.create<mhlo::DivOp>(loc, negPowerMulA, xMinusOne);
  Value s =
      rewriter.create<mhlo::AddOp>(loc, initialSum, negPowerMulADivXMinusOne);
  Value aInverseSquare = rewriter.create<mhlo::DivOp>(
      loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));

  Value hornerSum = zero;
  Value factor = one;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule allows us to avoid some NaNs and infs, resulting in
  // more numerically stable code.
  for (int i = 0; i < 11; ++i) {
    Value factorLhs = rewriter.create<mhlo::SubtractOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factorRhs = rewriter.create<mhlo::SubtractOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
    factor = rewriter.create<mhlo::MulOp>(loc, factorLhs, factorRhs);
    hornerSum = rewriter.create<mhlo::MulOp>(
        loc, factor,
        rewriter.create<mhlo::MulOp>(
            loc, aInverseSquare,
            rewriter.create<mhlo::AddOp>(
                loc, hornerSum,
                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  Value zeroPointFiveLikeNegPower =
      chlo::getConstantLike(rewriter, loc, .5, negPower);
  Value xDivA = rewriter.create<mhlo::DivOp>(loc, x, a);
  s = rewriter.create<mhlo::AddOp>(
      loc, s,
      rewriter.create<mhlo::MulOp>(
          loc, negPower,
          rewriter.create<mhlo::AddOp>(
              loc, zeroPointFiveLikeNegPower,
              rewriter.create<mhlo::MulOp>(
                  loc, xDivA,
                  rewriter.create<mhlo::AddOp>(
                      loc,
                      chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
                                            a),
                      hornerSum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  Value absNegPower = rewriter.create<mhlo::AbsOp>(loc, negPower);
  Value absInitialSum = rewriter.create<mhlo::AbsOp>(loc, initialSum);
  Value output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(
          loc, absNegPower,
          rewriter.create<mhlo::MulOp>(
              loc, absInitialSum,
              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          mhlo::ComparisonDirection::LT),
      initialSum, s);

  // The function is not defined for x < 1.
  Value nan = chlo::getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(loc, x, oneLikeX,
                                       mhlo::ComparisonDirection::LT),
      nan, output);

  // For q <= 0, x must be an integer.
  Value qLeZero = rewriter.create<mhlo::CompareOp>(
      loc, q, zero, mhlo::ComparisonDirection::LE);
  Value xNotInt = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
      mhlo::ComparisonDirection::NE);
  Value xDomainError = rewriter.create<mhlo::AndOp>(loc, qLeZero, xNotInt);
  output = rewriter.create<mhlo::SelectOp>(loc, xDomainError, nan, output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  Value inf = chlo::getConstantLike(rewriter, loc,
                                    std::numeric_limits<double>::infinity(), x);
  Value qIsInt = rewriter.create<mhlo::CompareOp>(
      loc, q, rewriter.create<mhlo::FloorOp>(loc, q),
      mhlo::ComparisonDirection::EQ);
  Value atPole = rewriter.create<mhlo::AndOp>(loc, qLeZero, qIsInt);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value xIsInt = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x),
      mhlo::ComparisonDirection::EQ);
  Value xIsEven = rewriter.create<mhlo::CompareOp>(
      loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero,
      mhlo::ComparisonDirection::EQ);
  Value xIsEvenInt = rewriter.create<mhlo::AndOp>(loc, xIsInt, xIsEven);
  output = rewriter.create<mhlo::SelectOp>(
      loc, atPole, rewriter.create<mhlo::SelectOp>(loc, xIsEvenInt, inf, nan),
      output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(loc, x, one,
                                       mhlo::ComparisonDirection::EQ),
      inf, output);

  return output;
}
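
// Scalar outline of the Hurwitz zeta evaluation above (special-case handling
// omitted; illustrative only): the first ten terms of sum_{k>=0} (q + k)^-x
// are summed directly; with a = q + 10, the Euler-Maclaurin tail
//   a^(1-x) / (x - 1) + a^-x * (1/2 + (x / a) * B(1 / a^2))
// is then added, where B denotes the Bernoulli-coefficient series evaluated
// with Horner's rule from kZetaCoeffs.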

Value materializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange args) {
  PolygammaOp::Adaptor transformed(args);
  Value n = transformed.n();
  Value x = transformed.x();

  // Handle integer n > 0.
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value two = getConstantLike(rewriter, loc, 2.0, x);
  Value sign = rewriter.create<mhlo::SubtractOp>(
      loc,
      rewriter.create<mhlo::MulOp>(loc, two,
                                   rewriter.create<mhlo::RemOp>(loc, n, two)),
      one);
  Value nPlusOne = rewriter.create<mhlo::AddOp>(loc, n, one);
  Value expLgammaNp1 = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<chlo::LgammaOp>(loc, nPlusOne));
  Value zeta = rewriter.create<chlo::ZetaOp>(loc, nPlusOne, x);
  Value result = rewriter.create<mhlo::MulOp>(
      loc, rewriter.create<mhlo::MulOp>(loc, sign, expLgammaNp1), zeta);

  // Handle n = 0.
  Value zero = getConstantLike(rewriter, loc, 0.0, x);
  Value nEqZero = rewriter.create<mhlo::CompareOp>(
      loc, n, zero, mhlo::ComparisonDirection::EQ);
  result = rewriter.create<mhlo::SelectOp>(
      loc, nEqZero, rewriter.create<chlo::DigammaOp>(loc, x), result);

  // Check that n is a natural number; return NaN otherwise.
  Value nonInt = rewriter.create<mhlo::CompareOp>(
      loc, n, rewriter.create<mhlo::FloorOp>(loc, n),
      mhlo::ComparisonDirection::NE);
  Value negative = rewriter.create<mhlo::CompareOp>(
      loc, n, zero, mhlo::ComparisonDirection::LT);
  Value nonNatural = rewriter.create<mhlo::OrOp>(loc, nonInt, negative);
  return rewriter.create<mhlo::SelectOp>(
      loc, nonNatural,
      getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
                      x),
      result);
}
1126 
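// The following converters materialize their approximations in at least f32
// precision: for narrower floating-point operands, materializeWithUpcast
// extends the input to f32 and truncates the result back, since the
// approximations would otherwise lose accuracy at lower precision.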
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
  using OpConversionPattern<LgammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      LgammaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  minPrecisionTy, &materializeLgamma));
    return success();
  }
};

struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
  using OpConversionPattern<DigammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      DigammaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  minPrecisionTy, &materializeDigamma));
    return success();
  }
};

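// Bit-twiddling implementation of nextafter: reinterpret the float as an
// integer of the same width and step its magnitude by one unit in the last
// place toward "y". As an illustration (f32): nextafter(1.0f, 2.0f) bitcasts
// 1.0f to 0x3f800000, adds 1 to obtain 0x3f800001, and bitcasts back to
// 1.00000012f, the next representable value above 1.0f.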
Value materializeNextAfter(ConversionPatternRewriter &rewriter, Location loc,
                           ValueRange operands) {
  NextAfterOp::Adaptor transformed(operands);
  Value x = transformed.x();
  Value y = transformed.y();
  auto resultTy = x.getType().cast<ShapedType>();
  auto bitwidth = resultTy.getElementType().getIntOrFloatBitWidth();
  ImplicitLocOpBuilder b(loc, rewriter);
  auto intTy = resultTy.clone(b.getIntegerType(bitwidth));
  auto xAsInt = b.create<mhlo::BitcastConvertOp>(intTy, x);
  auto yAsInt = b.create<mhlo::BitcastConvertOp>(intTy, y);

  // The result is NaN if either "x" or "y" is NaN.
  auto xIsNan = b.create<mhlo::CompareOp>(x, x, mhlo::ComparisonDirection::NE);
  auto yIsNan = b.create<mhlo::CompareOp>(y, y, mhlo::ComparisonDirection::NE);
  auto nanInput = b.create<mhlo::OrOp>(xIsNan, yIsNan);
  auto resultForNan = getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  auto resultForNanAsInt =
      b.create<mhlo::BitcastConvertOp>(intTy, resultForNan);

  // The sign bit is the MSB.
  const int64_t signBit = int64_t{1} << (bitwidth - 1);
  // Discard the sign bit to make the result non-negative.
  auto signMask = getConstantLike(rewriter, loc, signBit, xAsInt);
  auto negatedSignMask = getConstantLike(rewriter, loc, ~signBit, xAsInt);
  auto xAbs = b.create<mhlo::AndOp>(xAsInt, negatedSignMask);
  auto yAbs = b.create<mhlo::AndOp>(yAsInt, negatedSignMask);

  // When "x" and "y" are equal, the result is "y".
  auto xAndYAreEqual =
      b.create<mhlo::CompareOp>(x, y, mhlo::ComparisonDirection::EQ);
  auto resultForEqual = yAsInt;

  // When both "x" and "y" are 0, the result is "y". This is a separate case
  // from above because "x" and "y" might have different signs.
  auto zero = getConstantLike(rewriter, loc, 0, xAsInt);
  auto xIsZero =
      b.create<mhlo::CompareOp>(xAbs, zero, mhlo::ComparisonDirection::EQ);
  auto yIsZero =
      b.create<mhlo::CompareOp>(yAbs, zero, mhlo::ComparisonDirection::EQ);
  auto resultForBothZero = yAsInt;

  auto xSign = b.create<mhlo::AndOp>(xAsInt, signMask);
  auto ySign = b.create<mhlo::AndOp>(yAsInt, signMask);

  // If "x" == 0 and "y" != 0, we need to return the smallest subnormal number
  // with the sign of "y".
  auto one = getConstantLike(rewriter, loc, 1, xAsInt);
  auto resultForXZeroYNonZero = b.create<mhlo::OrOp>(ySign, one);

  // If the signs of "x" and "y" disagree:
  // - we need to make the magnitude of "x" smaller so that it is closer to
  //   zero.
  //
  // Otherwise the signs agree:
  // - "x" with a magnitude larger than "y" means we need to make the magnitude
  //   smaller.
  // - "x" with a magnitude smaller than "y" means we need to make the magnitude
  //   larger.
  auto signsDisagree =
      b.create<mhlo::CompareOp>(xSign, ySign, mhlo::ComparisonDirection::NE);
  auto xMagnitudeLargerThanY =
      b.create<mhlo::CompareOp>(xAbs, yAbs, mhlo::ComparisonDirection::GT);
  auto resultHasSmallerMagnitude =
      b.create<mhlo::OrOp>(xMagnitudeLargerThanY, signsDisagree);
  auto minusOne = getConstantLike(rewriter, loc, -1, xAsInt);
  auto magnitudeAdjustment =
      b.create<mhlo::SelectOp>(resultHasSmallerMagnitude, minusOne, one);
  Value result = b.create<mhlo::AddOp>(xAsInt, magnitudeAdjustment);
  // Handle x == +-0.
  result = b.create<mhlo::SelectOp>(
      xIsZero,
      b.create<mhlo::SelectOp>(yIsZero, resultForBothZero,
                               resultForXZeroYNonZero),
      result);
  // Handle x == y.
  result = b.create<mhlo::SelectOp>(xAndYAreEqual, resultForEqual, result);
  // Handle isnan(x) || isnan(y).
  result = b.create<mhlo::SelectOp>(nanInput, resultForNanAsInt, result);

  // Cast back to the original type.
  return b.create<mhlo::BitcastConvertOp>(resultTy, result);
}

struct ConvertNextAfterOp : public OpConversionPattern<NextAfterOp> {
  using OpConversionPattern<NextAfterOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      NextAfterOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(
        op, materializeNextAfter(rewriter, op.getLoc(), adaptor.getOperands()));
    return success();
  }
};

struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
  using OpConversionPattern<PolygammaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      PolygammaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  minPrecisionTy, &materializePolygamma));
    return success();
  }
};

// Sinh(x) = (e^x - e^-x) / 2
//         = e^(x + log(1/2)) - e^(-x + log(1/2)).
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not
// inf.
//
// This incorrectly overflows to +/-inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
Value materializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
                                            Location loc, ValueRange operands) {
  SinhOp::Adaptor transformed(operands);
  Value x = transformed.operand();

  Value logOneHalf =
      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
  Value expAdd = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::AddOp>(loc, x, logOneHalf));
  Value expSub = rewriter.create<mhlo::ExpOp>(
      loc, rewriter.create<mhlo::SubtractOp>(loc, logOneHalf, x));
  return rewriter.create<mhlo::SubtractOp>(loc, expAdd, expSub);
}

// Express `sinh` as
//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
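// For small |x| the direct formula suffers catastrophic cancellation: in f32,
// for instance, e^(1e-8) rounds to exactly 1.0 (the ULP at 1.0 is ~1.2e-7),
// so (e^x - e^-x) / 2 evaluates to 0 instead of ~1e-8. See the expm1-based
// rewrite in the function body.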
Value materializeSinhApproximation(ConversionPatternRewriter &rewriter,
                                   Location loc, ValueRange operands) {
  Value largeSinhResult =
      materializeSinhApproximationForLargeX(rewriter, loc, operands);

  SinhOp::Adaptor transformed(operands);
  Value x = transformed.operand();

  // For smaller x, we get unwanted cancellations of e^x - e^-x, resulting in
  // 0.
  // Rewrite this to avoid that. We use expm1(x) because that preserves the
  // first-order term of the Taylor series of e^x:
  //   (e^x - e^-x) / 2 =
  //   (e^x - 1 + 1 - e^-x) / 2 =
  //   (expm1(x) + (e^x - 1) / e^x) / 2 =
  //   (expm1(x) + expm1(x) / (expm1(x) + 1)) / 2.
  Value expm1 = rewriter.create<mhlo::Expm1Op>(loc, x);
  Value one = getConstantLike(rewriter, loc, 1.0, x);
  Value oneHalf = getConstantLike(rewriter, loc, 0.5, x);
  Value expm1PlusOne = rewriter.create<mhlo::AddOp>(loc, expm1, one);
  Value ratio = rewriter.create<mhlo::DivOp>(loc, expm1, expm1PlusOne);
  Value sum = rewriter.create<mhlo::AddOp>(loc, expm1, ratio);
  Value smallSinhResult = rewriter.create<mhlo::MulOp>(loc, oneHalf, sum);

  Value absX = rewriter.create<mhlo::AbsOp>(loc, x);
  Value absXLtOne = rewriter.create<mhlo::CompareOp>(
      loc, absX, one, mhlo::ComparisonDirection::LT);
  return rewriter.create<mhlo::SelectOp>(loc, absXLtOne, smallSinhResult,
                                         largeSinhResult);
}

struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
  using OpConversionPattern<SinhOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      SinhOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Value x = adaptor.operand();
    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
      rewriter.replaceOp(op, materializeSinhApproximationForLargeX(
                                 rewriter, op.getLoc(), adaptor.getOperands()));
      return success();
    }
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  rewriter.getF32Type(),
                                  &materializeSinhApproximation));
    return success();
  }
};

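// Express `tan` as
//   tan(x) = sin(x) / cos(x).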
Value materializeTan(ConversionPatternRewriter &rewriter, Location loc,
                     ValueRange operands) {
  TanOp::Adaptor transformed(operands);
  return rewriter.create<mhlo::DivOp>(
      loc, rewriter.create<mhlo::SineOp>(loc, transformed.operand()),
      rewriter.create<mhlo::CosineOp>(loc, transformed.operand()));
}

struct ConvertTanOp : public OpConversionPattern<TanOp> {
  using OpConversionPattern<TanOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      TanOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(),
                                  rewriter.getF32Type(), &materializeTan));
    return success();
  }
};

// Converts chlo.top_k to MHLO iota, sort, and slice ops.
//
// chlo.top_k sorts along the last dimension of the input tensor and then
// returns the top K components' values and indices. This is translated into a
// few ops in MHLO: first generating an integer sequence for the indices, then
// sorting both the original input tensor and the indices together, and
// finally slicing out the top K components.
//
// For example, for the following IR:
//
// %0:2 = "chlo.top_k"(%input, k=8): tensor<16x16xf32> ->
//                                   (tensor<16x8xf32>, tensor<16x8xi32>)
//
// We will get:
//
// %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2:2 = "mhlo.sort"(%input, %1) ({
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
//      %arg3: tensor<i32>, %arg4: tensor<i32>):
//   %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
//   "mhlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "mhlo.slice"(%2#0) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
//                          start_indices = dense<0> : tensor<2xi64>,
//                          strides = dense<1> : tensor<2xi64>} :
//                             (tensor<16x16xf32>) -> tensor<16x8xf32>
// %4 = "mhlo.slice"(%2#1) ...
struct ConvertTopKOp : public OpConversionPattern<TopKOp> {
  using OpConversionPattern<TopKOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      TopKOp op, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter &rewriter) const override {
    // The last dimension of the operand's shape should be known so we can have
    // clamped end_indices for slices. This is verified by the op.
    auto operandType = op.operand().getType().cast<RankedTensorType>();
    int64_t operandRank = operandType.getRank();
    int64_t lastDimIndex = operandRank - 1;
    int64_t lastDimSize = operandType.getDimSize(lastDimIndex);
    assert(lastDimSize != ShapedType::kDynamicSize);

    // Create an Iota op for indices.
    auto i32Type = rewriter.getIntegerType(32);
    Type iotaType = RankedTensorType::get(operandType.getShape(), i32Type);
    Value iotaOp = rewriter.create<mhlo::IotaOp>(
        op.getLoc(), iotaType, rewriter.getI64IntegerAttr(lastDimIndex));

    // Create the sort op. It takes two inputs: one for the original input and
    // the other for the indices. Use TOTALORDER comparison type instead of the
    // default comparison if the element type is of type float.
    Type elementType = operandType.getElementType();
    auto sortOp = createSortOp(&rewriter, op.getLoc(), {op.operand(), iotaOp},
                               {elementType, i32Type}, lastDimIndex,
                               /*isStable=*/true,
                               /*direction=*/mhlo::ComparisonDirection::GT);

    // Get the sorted values and the corresponding indices.
    auto sortedValues = sortOp.getResult(0);
    auto sortedIndices = sortOp.getResult(1);

    SmallVector<int64_t, 4> beginIndices(operandRank, 0);
    auto endIndices = llvm::to_vector<4>(operandType.getShape());
    endIndices.back() = std::min(static_cast<int64_t>(op.k()), lastDimSize);
    SmallVector<int64_t, 4> strides(operandRank, 1);

    // Get the slice for the top K elements.
    auto indicesTy = RankedTensorType::get(operandRank, rewriter.getI64Type());
    Value values = rewriter.create<mhlo::SliceOp>(
        op.getLoc(), sortedValues,
        DenseIntElementsAttr::get(indicesTy, beginIndices),
        DenseIntElementsAttr::get(indicesTy, endIndices),
        DenseIntElementsAttr::get(indicesTy, strides));
    Value indices = rewriter.create<mhlo::SliceOp>(
        op.getLoc(), sortedIndices,
        DenseIntElementsAttr::get(indicesTy, beginIndices),
        DenseIntElementsAttr::get(indicesTy, endIndices),
        DenseIntElementsAttr::get(indicesTy, strides));

    rewriter.replaceOp(op, {values, indices});
    return success();
  }
};

struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
  using OpConversionPattern<ZetaOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ZetaOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    FloatType minPrecisionTy = rewriter.getF32Type();
    rewriter.replaceOp(
        op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
                                  minPrecisionTy, &materializeZeta));
    return success();
  }
};

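// Converts chlo.broadcast_select to a shape-constrained mhlo.select by
// explicitly broadcasting all three operands to the common result shape. A
// rough sketch of the emitted IR (shapes are illustrative only):
//
//   %cstr = shape.cstr_broadcastable %pred_shape, %true_shape, %false_shape
//   %result = shape.assuming %cstr -> (tensor<?xf32>) {
//     ... mhlo.dynamic_broadcast_in_dim of each operand to the result
//         extents ...
//     %sel = mhlo.select %bcast_pred, %bcast_on_true, %bcast_on_false
//     shape.assuming_yield %sel
//   }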
struct ConvertSelectOp : public OpConversionPattern<BroadcastSelectOp> {
  using OpConversionPattern<BroadcastSelectOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      BroadcastSelectOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value pred = adaptor.pred();
    Value onTrue = adaptor.on_true();
    Value onFalse = adaptor.on_false();
    auto predType = pred.getType().dyn_cast<RankedTensorType>();
    auto onTrueType = onTrue.getType().dyn_cast<RankedTensorType>();
    auto onFalseType = onFalse.getType().dyn_cast<RankedTensorType>();
    auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!predType || !onTrueType || !onFalseType || !resultType) {
      return failure();
    }

    auto loc = op.getLoc();

    Value predShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, pred);
    Value onTrueShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onTrue);
    Value onFalseShape = rewriter.createOrFold<shape::ShapeOfOp>(loc, onFalse);
    int64_t resultRank = std::max(
        {predType.getRank(), onTrueType.getRank(), onFalseType.getRank()});

    Value broadcastableCstr = rewriter.createOrFold<shape::CstrBroadcastableOp>(
        loc, ValueRange{predShape, onTrueShape, onFalseShape});
    auto assumingOp = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{resultType}, broadcastableCstr);

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assumingOp.getDoRegion());

    Value resultExtents = rewriter.createOrFold<shape::BroadcastOp>(
        loc, shape::getExtentTensorType(op.getContext()),
        ValueRange{predShape, onTrueShape, onFalseShape},
        /*error=*/nullptr);
    auto shapeType =
        RankedTensorType::get({resultRank}, rewriter.getIndexType());
    resultExtents =
        rewriter.createOrFold<tensor::CastOp>(loc, shapeType, resultExtents);

    Value broadcastedPred = pred;
    // Pred has an implicit broadcast for scalars, so use that when convenient.
    if (predType.getRank() > 0) {
      auto predBroadcastDimensions = llvm::to_vector<4>(
          llvm::seq<int64_t>(resultRank - predType.getRank(), resultRank));
      broadcastedPred = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
          loc,
          RankedTensorType::get(resultType.getShape(),
                                predType.getElementType()),
          pred, resultExtents,
          rewriter.getI64TensorAttr(predBroadcastDimensions));
    }
    auto onTrueBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - onTrueType.getRank(), resultRank));
    Value broadcastedOnTrue = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(),
                              onTrueType.getElementType()),
        onTrue, resultExtents,
        rewriter.getI64TensorAttr(onTrueBroadcastDimensions));
    auto onFalseBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - onFalseType.getRank(), resultRank));
    Value broadcastedOnFalse = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(),
                              onFalseType.getElementType()),
        onFalse, resultExtents,
        rewriter.getI64TensorAttr(onFalseBroadcastDimensions));

    // And generate the final non-broadcasted ternary op.
    Value finalResult =
        rewriter.create<mhlo::SelectOp>(loc, resultType, broadcastedPred,
                                        broadcastedOnTrue, broadcastedOnFalse);
    rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
    rewriter.replaceOp(op, {assumingOp.getResult(0)});
    return success();
  }
};

// Converts binary ops that are statically determined not to broadcast
// directly to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    auto lhsType =
        adaptor.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhsType =
        adaptor.rhs().getType().template dyn_cast<RankedTensorType>();
    if (!lhsType || !rhsType) return failure();

    // Differing ranks would require a rank broadcast.
    if (lhsType.getRank() != rhsType.getRank()) return failure();
    // Any dynamic dimension may require broadcasting and requires more
    // analysis.
    if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape())
      return failure();

    for (auto extents : llvm::zip(lhsType.getShape(), rhsType.getShape())) {
      auto lhsExtent = std::get<0>(extents);
      auto rhsExtent = std::get<1>(extents);
      if (lhsExtent != rhsExtent) {
        return failure();
      }
    }

    rewriter.replaceOp(op,
                       {Adaptor::createOp(op, op.getResult().getType(),
                                          adaptor.getOperands(), rewriter)});
    return success();
  }
};

// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
//   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
//   - Same-rank broadcast (operands have the same static rank).
//   - Different-rank broadcast, either without a broadcast_dims attribute or
//     with the broadcast_dims attribute set to map to a prefix padding.
//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
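//
// As an illustration, broadcasting tensor<?xf32> with tensor<?x?xf32>
// (result rank 2) assigns the rank-1 operand the broadcast dimensions [1],
// i.e. it is prefix-padded to align with the trailing dimension of the
// result; broadcastability itself is only verified at runtime through the
// shape.cstr_broadcastable constraint emitted below.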
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpConversionPattern<ChloOpTy> {
  using OpConversionPattern<ChloOpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      ChloOpTy op, typename ChloOpTy::Adaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value lhs = adaptor.lhs();
    Value rhs = adaptor.rhs();
    auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
    auto resultType =
        op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!lhsType || !rhsType || !resultType) return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcastDimensions = op.broadcast_dimensions();
    if (broadcastDimensions &&
        !hlo::isLegalNumpyRankedBroadcast(lhs, rhs, *broadcastDimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcastDimensions;
      return failure();
    }

    // Compute the result shape.
    auto loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhsShape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhsShape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastableCstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhsShape, rhsShape);
    auto assumingOp = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{resultType}, broadcastableCstr.getResult());

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assumingOp.getDoRegion());

    int64_t resultRank = std::max(lhsType.getRank(), rhsType.getRank());
    Value resultExtents =
        hlo::computeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                               rewriter);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit them, and some of them require analysis to
    // prove properly.
    auto lhsBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - lhsType.getRank(), resultRank));
    Value broadcastedLhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(), lhsType.getElementType()),
        lhs, resultExtents, rewriter.getI64TensorAttr(lhsBroadcastDimensions));
    auto rhsBroadcastDimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(resultRank - rhsType.getRank(), resultRank));
    Value broadcastedRhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(resultType.getShape(), rhsType.getElementType()),
        rhs, resultExtents, rewriter.getI64TensorAttr(rhsBroadcastDimensions));

    // And generate the final non-broadcasted binary op.
    Value finalResult = Adaptor::createOp(
        op, resultType, {broadcastedLhs, broadcastedRhs}, rewriter);
    rewriter.create<shape::AssumingYieldOp>(loc, finalResult);
    rewriter.replaceOp(op, {assumingOp.getResult(0)});
    return success();
  }
};

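// Lowers chlo.dynamic_reshape to mhlo.dynamic_reshape, guarded by a
// mhlo.cstr_reshapable constraint that the number of input elements is
// expressible in the requested output shape. Inside the resulting assuming
// region, mhlo.compute_reshape_shape derives the concrete target shape from
// the element count and the requested shape.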
class ConvertDynamicReshapeOp
    : public OpRewritePattern<chlo::DynamicReshapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(chlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto tensor = op.operand();
    auto shape = op.output_shape();

    auto shapeTy = shape.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, tensor);
    Value numEls = rewriter.create<shape::NumElementsOp>(loc, inputShape);
    Value cstr = rewriter.create<mhlo::CstrReshapableOp>(loc, numEls, shape);
    rewriter.replaceOpWithNewOp<shape::AssumingOp>(
        op, cstr, [&](OpBuilder &b, Location l) {
          Value computedShape =
              b.create<mhlo::ComputeReshapeShapeOp>(l, shapeTy, numEls, shape);
          SmallVector<Value> result;
          result.push_back(b.create<mhlo::DynamicReshapeOp>(l, resultTy, tensor,
                                                            computedShape));
          return result;
        });

    return success();
  }
};

#include "generated_chlo_legalize_to_hlo.inc"
}  // namespace

void populateChloBroadcastingPatterns(MLIRContext *context,
                                      RewritePatternSet *patterns) {
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
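  // The pattern benefits (10 vs. 5) bias the rewrite driver toward the
  // trivial non-broadcasting lowering whenever both patterns would apply.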
  populateForBroadcastingBinaryOp<ConvertTrivialNonBroadcastBinaryOp>(
      context, patterns, 10);
  populateForBroadcastingBinaryOp<ConvertRankedDynamicBroadcastBinaryOp>(
      context, patterns, 5);
  patterns
      ->add<ConvertConstantLikeOp, ConvertDynamicReshapeOp, ConvertSelectOp>(
          context);
}

void populateDecomposeChloPatterns(MLIRContext *context,
                                   RewritePatternSet *patterns) {
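  // populateWithGenerated adds the TableGen-defined (DRR) decompositions
  // emitted into generated_chlo_legalize_to_hlo.inc, included above.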
  populateWithGenerated(*patterns);

  // Other patterns.
  // clang-format off
  patterns->add<ConvertBesselI1eOp,
                ConvertCoshOp,
                ConvertDigammaOp,
                ConvertErfOp,
                ConvertErfcOp,
                ConvertLgammaOp,
                ConvertNextAfterOp,
                ConvertPolygammaOp,
                ConvertSinhOp,
                ConvertTanOp,
                ConvertTopKOp,
                ConvertZetaOp>(context);
  // clang-format on
}

}  // namespace chlo
}  // namespace mlir