/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Dialect/Math/IR/Math.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
23
24 namespace tensorflow {
25 namespace {
26
27 #define GEN_PASS_CLASSES
28 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
29
30 using ::llvm::ArrayRef;
31 using ::llvm::SmallVector;
32
33 using ::mlir::ImplicitLocOpBuilder;
34 using ::mlir::LogicalResult;
35 using ::mlir::OpRewritePattern;
36 using ::mlir::PatternRewriter;
37 using ::mlir::RewritePatternSet;
38 using ::mlir::Type;
39 using ::mlir::Value;
40 using ::mlir::VectorType;
41
42 namespace arith = ::mlir::arith;
43 namespace math = ::mlir::math;
44 namespace vector = ::mlir::vector;
45
46 using TypePredicate = ::llvm::function_ref<bool(Type)>;
47
48 // Returns vector shape if the element type is matching the predicate (scalars
49 // that do match the predicate have shape equal to `{1}`).
vectorShape(Type type,TypePredicate pred)50 static llvm::Optional<SmallVector<int64_t, 2>> vectorShape(Type type,
51 TypePredicate pred) {
52 // If the type matches the predicate then its shape is `{1}`.
53 if (pred(type)) return SmallVector<int64_t, 2>{1};
54
55 // Otherwise check if the type is a vector type.
56 auto vectorType = type.dyn_cast<VectorType>();
57 if (vectorType && pred(vectorType.getElementType())) {
58 return llvm::to_vector<2>(vectorType.getShape());
59 }
60
61 return llvm::None;
62 }
63
64 // Returns vector shape of the type. If the type is a scalar returns `1`.
vectorShape(Type type)65 static SmallVector<int64_t, 2> vectorShape(Type type) {
66 auto vectorType = type.dyn_cast<VectorType>();
67 return vectorType ? llvm::to_vector<2>(vectorType.getShape())
68 : SmallVector<int64_t, 2>{1};
69 }
70
71 // Returns vector element type. If the type is a scalar returns the argument.
elementType(Type type)72 static Type elementType(Type type) {
73 auto vectorType = type.dyn_cast<VectorType>();
74 return vectorType ? vectorType.getElementType() : type;
75 }
76
// Type predicate (for `vectorShape`) matching the 32-bit float type.
static bool isF32(Type type) { return type.isF32(); }
78
79 //----------------------------------------------------------------------------//
80 // Broadcast scalar types and values into vector types and values.
81 //----------------------------------------------------------------------------//
82
83 // Returns true if shape != {1}.
isNonScalarShape(ArrayRef<int64_t> shape)84 static bool isNonScalarShape(ArrayRef<int64_t> shape) {
85 return shape.size() > 1 || shape[0] > 1;
86 }
87
88 // Broadcasts scalar type into vector type (iff shape is non-scalar).
broadcast(Type type,ArrayRef<int64_t> shape)89 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
90 assert(!type.isa<VectorType>() && "must be scalar type");
91 return isNonScalarShape(shape) ? VectorType::get(shape, type) : type;
92 }
93
94 // Broadcasts scalar value into vector (iff shape is non-scalar).
broadcast(ImplicitLocOpBuilder & builder,Value value,ArrayRef<int64_t> shape)95 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
96 ArrayRef<int64_t> shape) {
97 assert(!value.getType().isa<VectorType>() && "must be scalar value");
98 auto type = broadcast(value.getType(), shape);
99 return isNonScalarShape(shape)
100 ? builder.create<vector::BroadcastOp>(type, value)
101 : value;
102 }
103
104 //----------------------------------------------------------------------------//
105 // Helper functions to create constants.
106 //----------------------------------------------------------------------------//
107
// Creates a scalar f32 `arith.constant` with the given value.
static Value f32Cst(ImplicitLocOpBuilder &builder, float value) {
  return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
}
111
// Creates a scalar i32 `arith.constant` with the given value.
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
  return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
}
115
116 //----------------------------------------------------------------------------//
117 // Helper functions to build math function approximations.
118 //----------------------------------------------------------------------------//
119
min(ImplicitLocOpBuilder & builder,Value a,Value b)120 static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) {
121 return builder.create<mlir::arith::SelectOp>(
122 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b);
123 }
124
max(ImplicitLocOpBuilder & builder,Value a,Value b)125 static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) {
126 return builder.create<mlir::arith::SelectOp>(
127 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b);
128 }
129
clamp(ImplicitLocOpBuilder & builder,Value value,Value lowerBound,Value upperBound)130 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
131 Value upperBound) {
132 return max(builder, min(builder, value, upperBound), lowerBound);
133 }
134
135 // Eigen's implementation of ldexp.
136 // ldexp(x, exp) = x * 2^exp
137 // Set e = min(max(exp, -278), 278)
138 // b = floor(e/4)
139 // Then out = ((((x * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
ldexp(ImplicitLocOpBuilder & builder,Value x,Value exp)140 static Value ldexp(ImplicitLocOpBuilder &builder, Value x, Value exp) {
141 assert(isF32(elementType(x.getType())) && "argument x must be f32 type");
142 assert(isF32(elementType(exp.getType())) && "argument exp must be f32 type");
143
144 auto shape = vectorShape(x.getType());
145 auto exp_shape = vectorShape(exp.getType());
146 assert(shape == exp_shape && "x and exp must be of equal shape");
147 auto f32Vec = broadcast(builder.getF32Type(), shape);
148 auto i32Vec = broadcast(builder.getI32Type(), shape);
149
150 auto bcast = [&](Value value) -> Value {
151 return broadcast(builder, value, shape);
152 };
153 auto mulf = [&](Value a, Value b) -> Value {
154 return builder.create<arith::MulFOp>(a, b);
155 };
156 auto subi = [&](Value a, Value b) -> Value {
157 return builder.create<arith::SubIOp>(a, b);
158 };
159 auto shli = [&](Value a, Value pos) -> Value {
160 return builder.create<arith::ShLIOp>(a, pos);
161 };
162
163 Value cstMantBitsI = bcast(i32Cst(builder, 23));
164 Value cstMaxExponent = bcast(f32Cst(builder, 278.0f));
165 Value cstMinExponent = bcast(f32Cst(builder, -278.0f));
166 Value cstBiasI = bcast(i32Cst(builder, 127));
167 Value cst2I = bcast(i32Cst(builder, 2));
168
169 Value e = clamp(builder, exp, cstMinExponent, cstMaxExponent);
170 Value eI = builder.create<arith::FPToSIOp>(i32Vec, e);
171 Value bI = builder.create<arith::ShRSIOp>(eI, cst2I);
172 Value biasedBI = builder.create<arith::AddIOp>(bI, cstBiasI);
173 Value c = builder.create<arith::BitcastOp>(
174 f32Vec, shli(biasedBI, cstMantBitsI)); // 2^b
175 Value out = mulf(mulf(mulf(x, c), c), c); // x * 2^(3b)
176 bI = subi(subi(subi(eI, bI), bI), bI); // e - 3b
177 biasedBI = builder.create<arith::AddIOp>(bI, cstBiasI); // 2^(e - 3b)
178 c = builder.create<arith::BitcastOp>(f32Vec, shli(biasedBI, cstMantBitsI));
179 out = mulf(out, c);
180 return out;
181 }
182
// Rewrite pattern replacing `math.exp` on f32 scalars/vectors with a
// Cephes-style polynomial approximation (see `matchAndRewrite` below).
struct EigenExpApproximation : public OpRewritePattern<math::ExpOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpOp op,
                                PatternRewriter &rewriter) const final;
};
190
// Rewrites `math.exp` into a Cephes-style polynomial approximation for f32
// scalars/vectors; bails out on any other operand type.
LogicalResult EigenExpApproximation::matchAndRewrite(
    math::ExpOp op, PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.getOperand().getType(), isF32);
  if (!shape.has_value())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");
  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // Short-hand builders for the ops used below.
  auto addf = [&](Value a, Value b) -> Value {
    return builder.create<arith::AddFOp>(a, b);
  };
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, *shape);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
  auto fma = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mulf = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstHalf = bcast(f32Cst(builder, 0.5f));
  // f32 exp overflows/underflows outside roughly [-88.723, 88.723]; clamp the
  // input to that range first.
  Value cstExpHi = bcast(f32Cst(builder, 88.723f));
  Value cstExpLo = bcast(f32Cst(builder, -88.723f));

  // log2(e) and the Cephes polynomial coefficients for exp.
  Value cstCephesLog2E = bcast(f32Cst(builder, 1.44269504088896341f));
  Value cstCephesExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
  Value cstCephesExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
  Value cstCephesExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
  Value cstCephesExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
  Value cstCephesExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
  Value cstCephesExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));

  // m = round(x * log2(e)), computed as floor(x * log2(e) + 0.5).
  Value x = clamp(builder, op.getOperand(), cstExpLo, cstExpHi);
  Value m = floor(fma(x, cstCephesLog2E, cstHalf));

  // r = x - m * ln(2), with -ln(2) split into C1 + C2 for extra precision.
  Value cstCephesExpC1 = bcast(f32Cst(builder, -0.693359375f));
  Value cstCephesExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
  Value r = fma(m, cstCephesExpC1, x);
  r = fma(m, cstCephesExpC2, r);

  Value r2 = mulf(r, r);
  Value r3 = mulf(r2, r);

  // Evaluate the polynomial as two short fma chains (y and y1) merged through
  // r2/r3 — fewer sequential dependencies than a single Horner chain.
  Value y = fma(cstCephesExpP0, r, cstCephesExpP1);
  Value y1 = fma(cstCephesExpP3, r, cstCephesExpP4);
  Value y2 = addf(r, cstOne);
  y = fma(y, r, cstCephesExpP2);
  y1 = fma(y1, r, cstCephesExpP5);
  y = fma(y, r3, y1);
  y = fma(y, r2, y2);
  // Scale by 2^m. The final max against the unclamped operand makes the OGT
  // comparison fail for NaN inputs, so NaN (and +inf) operands pass through.
  Value ret = max(builder, ldexp(builder, y, m), op.getOperand());
  rewriter.replaceOp(op, ret);
  return mlir::success();
}
247
// Rewrite pattern replacing `math.expm1` on f32 scalars/vectors with a
// numerically careful exp(x) - 1 expansion (see `matchAndRewrite` below).
struct EigenExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(math::ExpM1Op op,
                                PatternRewriter &rewriter) const final;
};
255
matchAndRewrite(math::ExpM1Op op,PatternRewriter & rewriter) const256 LogicalResult EigenExpM1Approximation::matchAndRewrite(
257 math::ExpM1Op op, PatternRewriter &rewriter) const {
258 auto shape = vectorShape(op.getOperand().getType(), isF32);
259 if (!shape.hasValue())
260 return rewriter.notifyMatchFailure(op, "unsupported operand type");
261
262 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
263 auto bcast = [&](Value value) -> Value {
264 return broadcast(builder, value, *shape);
265 };
266
267 // expm1(x) = exp(x) - 1 = u - 1.
268 // We have to handle it carefully when x is near 0, i.e. u ~= 1,
269 // and when the input is ~= -inf, i.e. u - 1 ~= -1.
270 Value cstOne = bcast(f32Cst(builder, 1.0f));
271 Value cstNegOne = bcast(f32Cst(builder, -1.0f));
272 Value x = op.getOperand();
273 Value u = builder.create<math::ExpOp>(x);
274 Value uEqOneOrNaN =
275 builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
276 Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
277 Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
278 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
279 // logU = log(u) ~= x
280 Value logU = builder.create<math::LogOp>(u);
281
282 // Detect exp(x) = +inf; written this way to avoid having to form +inf.
283 Value isInf =
284 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
285
286 // (u - 1) * (x / ~x)
287 Value expm1 = builder.create<arith::MulFOp>(
288 uMinusOne, builder.create<arith::DivFOp>(x, logU));
289 expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
290 Value approximation = builder.create<arith::SelectOp>(
291 uEqOneOrNaN, x,
292 builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
293 rewriter.replaceOp(op, approximation);
294
295 return mlir::success();
296 }
297
populateMathApproximationPatterns(RewritePatternSet & patterns,ArrayRef<std::string> oplist)298 static void populateMathApproximationPatterns(RewritePatternSet &patterns,
299 ArrayRef<std::string> oplist) {
300 for (const std::string &op : oplist) {
301 if (op == "all") {
302 patterns.add<EigenExpApproximation, EigenExpM1Approximation>(
303 patterns.getContext());
304 } else if (op == "exp") {
305 patterns.add<EigenExpApproximation>(patterns.getContext());
306 } else if (op == "expm1") {
307 patterns.add<EigenExpM1Approximation>(patterns.getContext());
308 }
309 }
310 }
311
// Pass replacing supported `math` dialect ops with software approximations.
// The set of ops to rewrite is selected by the `oplist` pass option.
struct MathApproximationPass
    : public MathApproximationBase<MathApproximationPass> {
  // `approx_oplist` holds the op names to approximate ("all"/"exp"/"expm1");
  // it is copied into the generated pass option `oplist`.
  explicit MathApproximationPass(ArrayRef<std::string> approx_oplist) {
    this->oplist = approx_oplist;
  }

  void runOnOperation() override;
};
320
runOnOperation()321 void MathApproximationPass::runOnOperation() {
322 mlir::RewritePatternSet patterns(&getContext());
323 populateMathApproximationPatterns(patterns, oplist);
324 if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
325 std::move(patterns))))
326 signalPassFailure();
327 }
328
329 } // namespace
330
// Creates a pass approximating the `math` ops named in `oplist`
// ("all", "exp", "expm1").
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateMathApproximationPass(ArrayRef<std::string> oplist) {
  return std::make_unique<MathApproximationPass>(oplist);
}
335
336 } // namespace tensorflow
337