1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Math/IR/Math.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
23 
24 namespace tensorflow {
25 namespace {
26 
27 #define GEN_PASS_CLASSES
28 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
29 
30 using ::llvm::ArrayRef;
31 using ::llvm::SmallVector;
32 
33 using ::mlir::ImplicitLocOpBuilder;
34 using ::mlir::LogicalResult;
35 using ::mlir::OpRewritePattern;
36 using ::mlir::PatternRewriter;
37 using ::mlir::RewritePatternSet;
38 using ::mlir::Type;
39 using ::mlir::Value;
40 using ::mlir::VectorType;
41 
42 namespace arith = ::mlir::arith;
43 namespace math = ::mlir::math;
44 namespace vector = ::mlir::vector;
45 
46 using TypePredicate = ::llvm::function_ref<bool(Type)>;
47 
48 // Returns vector shape if the element type is matching the predicate (scalars
49 // that do match the predicate have shape equal to `{1}`).
vectorShape(Type type,TypePredicate pred)50 static llvm::Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
51                                                            TypePredicate pred) {
52   // If the type matches the predicate then its shape is `{1}`.
53   if (pred(type)) return SmallVector<int64_t, 2>{1};
54 
55   // Otherwise check if the type is a vector type.
56   auto vectorType = type.dyn_cast<VectorType>();
57   if (vectorType && pred(vectorType.getElementType())) {
58     return llvm::to_vector<2>(vectorType.getShape());
59   }
60 
61   return llvm::None;
62 }
63 
64 // Returns vector shape of the type. If the type is a scalar returns `1`.
vectorShape(Type type)65 static SmallVector<int64_t, 2> vectorShape(Type type) {
66   auto vectorType = type.dyn_cast<VectorType>();
67   return vectorType ? llvm::to_vector<2>(vectorType.getShape())
68                     : SmallVector<int64_t, 2>{1};
69 }
70 
71 // Returns vector element type. If the type is a scalar returns the argument.
elementType(Type type)72 static Type elementType(Type type) {
73   auto vectorType = type.dyn_cast<VectorType>();
74   return vectorType ? vectorType.getElementType() : type;
75 }
76 
isF32(Type type)77 static bool isF32(Type type) { return type.isF32(); }
78 
79 //----------------------------------------------------------------------------//
80 // Broadcast scalar types and values into vector types and values.
81 //----------------------------------------------------------------------------//
82 
83 // Returns true if shape != {1}.
isNonScalarShape(ArrayRef<int64_t> shape)84 static bool isNonScalarShape(ArrayRef<int64_t> shape) {
85   return shape.size() > 1 || shape[0] > 1;
86 }
87 
88 // Broadcasts scalar type into vector type (iff shape is non-scalar).
broadcast(Type type,ArrayRef<int64_t> shape)89 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
90   assert(!type.isa<VectorType>() && "must be scalar type");
91   return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
92 }
93 
94 // Broadcasts scalar value into vector (iff shape is non-scalar).
broadcast(ImplicitLocOpBuilder & builder,Value value,ArrayRef<int64_t> shape)95 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
96                        ArrayRef<int64_t> shape) {
97   assert(!value.getType().isa<VectorType>() && "must be scalar value");
98   auto type = broadcast(value.getType(), shape);
99   return isNonScalarShape(shape)
100              ? builder.create<vector::BroadcastOp>(type, value)
101              : value;
102 }
103 
104 //----------------------------------------------------------------------------//
105 // Helper functions to create constants.
106 //----------------------------------------------------------------------------//
107 
f32Cst(ImplicitLocOpBuilder & builder,float value)108 static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
109   return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
110 }
111 
i32Cst(ImplicitLocOpBuilder & builder,int32_t value)112 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
113   return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
114 }
115 
116 //----------------------------------------------------------------------------//
117 // Helper functions to build math function approximations.
118 //----------------------------------------------------------------------------//
119 
min(ImplicitLocOpBuilder & builder,Value a,Value b)120 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) {
121   return builder.create<mlir::arith::SelectOp>(
122       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b);
123 }
124 
max(ImplicitLocOpBuilder & builder,Value a,Value b)125 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) {
126   return builder.create<mlir::arith::SelectOp>(
127       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b);
128 }
129 
clamp(ImplicitLocOpBuilder & builder,Value value,Value lowerBound,Value upperBound)130 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
131                    Value upperBound) {
132   return max(builder, min(builder, value, upperBound), lowerBound);
133 }
134 
135 // Eigen's implementation of ldexp.
136 // ldexp(x, exp) = x * 2^exp
137 // Set e = min(max(exp, -278), 278)
138 //     b = floor(e/4)
139 // Then out = ((((x * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
// Emits IR computing ldexp(x, exp) = x * 2^exp for f32 scalars/vectors, by
// multiplying `x` with powers of two constructed directly from exponent bits
// (see the algorithm sketch in the comment above this function).
static Value ldexp(ImplicitLocOpBuilder &builder, Value x, Value exp) {
  assert(isF32(elementType(x.getType())) && "argument x must be f32 type");
  assert(isF32(elementType(exp.getType())) && "argument exp must be f32 type");

  auto shape = vectorShape(x.getType());
  auto exp_shape = vectorShape(exp.getType());
  assert(shape == exp_shape && "x and exp must be of equal shape");
  // f32/i32 types broadcast to the operand shape (stay scalar for shape {1}).
  auto f32Vec = broadcast(builder.getF32Type(), shape);
  auto i32Vec = broadcast(builder.getI32Type(), shape);

  // Shorthand emitters; `bcast` splats scalar constants to the operand shape.
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto mulf = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto subi = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubIOp>(a, b);
  };
  auto shli = [&](Value a, Value pos) -> Value {
    return builder.create<arith::ShLIOp>(a, pos);
  };

  // 23 = number of f32 mantissa bits; shifting a biased exponent left by 23
  // places it into the exponent field of an f32 bit pattern.
  Value cstMantBitsI = bcast(i32Cst(builder, 23));
  // Exponent clamp bounds taken from Eigen's pldexp implementation.
  Value cstMaxExponent = bcast(f32Cst(builder, 278.0f));
  Value cstMinExponent = bcast(f32Cst(builder, -278.0f));
  // 127 = f32 exponent bias.
  Value cstBiasI = bcast(i32Cst(builder, 127));
  Value cst2I = bcast(i32Cst(builder, 2));

  // e = min(max(exp, -278), 278), then converted to i32.
  Value e = clamp(builder, exp, cstMinExponent, cstMaxExponent);
  Value eI = builder.create<arith::FPToSIOp>(i32Vec, e);
  // b = e >> 2 (arithmetic shift), i.e. floor(e / 4) as described above.
  Value bI = builder.create<arith::ShRSIOp>(eI, cst2I);
  // Bit-assemble 2^b: (b + bias) << mantissa_bits reinterpreted as f32.
  Value biasedBI = builder.create<arith::AddIOp>(bI, cstBiasI);
  Value c = builder.create<arith::BitcastOp>(
      f32Vec, shli(biasedBI, cstMantBitsI));               // 2^b
  Value out = mulf(mulf(mulf(x, c), c), c);                // x * 2^(3b)
  bI = subi(subi(subi(eI, bI), bI), bI);                   // e - 3b
  biasedBI = builder.create<arith::AddIOp>(bI, cstBiasI);  // 2^(e - 3b)
  c = builder.create<arith::BitcastOp>(f32Vec, shli(biasedBI, cstMantBitsI));
  // Final result: x * 2^(3b) * 2^(e - 3b) = x * 2^e.
  out = mulf(out, c);
  return out;
}
182 
// Rewrites math.exp on f32 scalars/vectors into a polynomial approximation
// ported from Eigen/Cephes: exp(x) = 2^m * P(r), where m is an integer and
// r is the reduced argument.
struct EigenExpApproximation : public OpRewritePattern<math::ExpOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  // Emits the approximation and replaces `op`; fails for non-f32 operands.
  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
190 
// Replaces math.exp with the Cephes-style polynomial approximation:
//   exp(x) = 2^m * p(r),  m = round(x * log2(e)),  r = x - m * ln(2).
LogicalResult EigenExpApproximation::matchAndRewrite(
    math::ExpOp op, PatternRewriter &rewriter) const {
  // Only f32 scalars and f32-element vectors are supported.
  auto shape = vectorShape(op.getOperand().getType(), isF32);
  if (!shape.has_value())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");
  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // Shorthand emitters; `bcast` splats scalar constants to the operand shape.
  auto addf = [&](Value a, Value b) -> Value {
    return builder.create<arith::AddFOp>(a, b);
  };
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
  auto fma = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mulf = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstHalf = bcast(f32Cst(builder, 0.5f));
  // Input clamp bounds (~= +/- ln(FLT_MAX)); inputs outside this range would
  // overflow/underflow the f32 result anyway.
  Value cstExpHi = bcast(f32Cst(builder, 88.723f));
  Value cstExpLo = bcast(f32Cst(builder, -88.723f));

  // log2(e) and the Cephes expf polynomial coefficients p0..p5.
  Value cstCephesLog2E = bcast(f32Cst(builder, 1.44269504088896341f));
  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
  Value cstCephesExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
  Value cstCephesExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
  Value cstCephesExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
  Value cstCephesExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));

  // m = floor(x * log2(e) + 0.5), i.e. x * log2(e) rounded to nearest.
  Value x = clamp(builder, op.getOperand(), cstExpLo, cstExpHi);
  Value m = floor(fma(x, cstCephesLog2E, cstHalf));

  // r = x - m * ln(2), with ln(2) split into two constants (c1 + c2) so the
  // subtraction is performed in extended effective precision.
  Value cstCephesExpC1 = bcast(f32Cst(builder, -0.693359375f));
  Value cstCephesExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
  Value r = fma(m, cstCephesExpC1, x);
  r = fma(m, cstCephesExpC2, r);

  Value r2 = mulf(r, r);
  Value r3 = mulf(r2, r);

  // Evaluate the degree-5 polynomial p(r) with two interleaved Horner chains
  // combined via r^2 and r^3 (fewer sequential dependencies than plain Horner).
  Value y = fma(cstCephesExpP0, r, cstCephesExpP1);
  Value y1 = fma(cstCephesExpP3, r, cstCephesExpP4);
  Value y2 = addf(r, cstOne);
  y = fma(y, r, cstCephesExpP2);
  y1 = fma(y1, r, cstCephesExpP5);
  y = fma(y, r3, y1);
  y = fma(y, r2, y2);
  // Scale by 2^m. Taking max with the original operand keeps exp(+inf) = +inf
  // and propagates NaN inputs (the ordered compare in `max` selects the
  // second argument for unordered operands).
  Value ret = max(builder, ldexp(builder, y, m), op.getOperand());
  rewriter.replaceOp(op, ret);
  return mlir::success();
}
247 
// Rewrites math.expm1 on f32 scalars/vectors into an exp-based approximation
// with corrections for the regions where exp(x) - 1 loses accuracy
// (x near 0 and x -> -inf), following Eigen's generic_expm1.
struct EigenExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
 public:
  using OpRewritePattern::OpRewritePattern;

  // Emits the approximation and replaces `op`; fails for non-f32 operands.
  LogicalResult matchAndRewrite(math::ExpM1Op op,
                                PatternRewriter &rewriter) const final;
};
255 
matchAndRewrite(math::ExpM1Op op,PatternRewriter & rewriter) const256 LogicalResult EigenExpM1Approximation::matchAndRewrite(
257     math::ExpM1Op op, PatternRewriter &rewriter) const {
258   auto shape = vectorShape(op.getOperand().getType(), isF32);
259   if (!shape.hasValue())
260     return rewriter.notifyMatchFailure(op, "unsupported operand type");
261 
262   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
263   auto bcast = [&](Value value) -> Value {
264     return broadcast(builder, value, *shape);
265   };
266 
267   // expm1(x) = exp(x) - 1 = u - 1.
268   // We have to handle it carefully when x is near 0, i.e. u ~= 1,
269   // and when the input is ~= -inf, i.e. u - 1 ~= -1.
270   Value cstOne = bcast(f32Cst(builder, 1.0f));
271   Value cstNegOne = bcast(f32Cst(builder, -1.0f));
272   Value x = op.getOperand();
273   Value u = builder.create<math::ExpOp>(x);
274   Value uEqOneOrNaN =
275       builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
276   Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
277   Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
278       arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
279   // logU = log(u) ~= x
280   Value logU = builder.create<math::LogOp>(u);
281 
282   // Detect exp(x) = +inf; written this way to avoid having to form +inf.
283   Value isInf =
284       builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
285 
286   // (u - 1) * (x / ~x)
287   Value expm1 = builder.create<arith::MulFOp>(
288       uMinusOne, builder.create<arith::DivFOp>(x, logU));
289   expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
290   Value approximation = builder.create<arith::SelectOp>(
291       uEqOneOrNaN, x,
292       builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
293   rewriter.replaceOp(op, approximation);
294 
295   return mlir::success();
296 }
297 
populateMathApproximationPatterns(RewritePatternSet & patterns,ArrayRef<std::string> oplist)298 static void populateMathApproximationPatterns(RewritePatternSet &patterns,
299                                               ArrayRef<std::string> oplist) {
300   for (const std::string &op : oplist) {
301     if (op == "all") {
302       patterns.add<EigenExpApproximation, EigenExpM1Approximation>(
303           patterns.getContext());
304     } else if (op == "exp") {
305       patterns.add<EigenExpApproximation>(patterns.getContext());
306     } else if (op == "expm1") {
307       patterns.add<EigenExpM1Approximation>(patterns.getContext());
308     }
309   }
310 }
311 
// Pass that replaces supported `math` dialect ops with faster polynomial
// approximations. The set of ops to rewrite is selected by the `oplist`
// pass option (see populateMathApproximationPatterns for recognized names).
struct MathApproximationPass
    : public MathApproximationBase<MathApproximationPass> {
  explicit MathApproximationPass(ArrayRef<std::string> approx_oplist) {
    // `oplist` is the generated pass-option member inherited from the base.
    this->oplist = approx_oplist;
  }

  void runOnOperation() override;
};
320 
runOnOperation()321 void MathApproximationPass::runOnOperation() {
322   mlir::RewritePatternSet patterns(&getContext());
323   populateMathApproximationPatterns(patterns, oplist);
324   if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
325                                                 std::move(patterns))))
326     signalPassFailure();
327 }
328 
329 }  // namespace
330 
331 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateMathApproximationPass(ArrayRef<std::string> oplist)332 CreateMathApproximationPass(ArrayRef<std::string> oplist) {
333   return std::make_unique<MathApproximationPass>(oplist);
334 }
335 
336 }  // namespace tensorflow
337