// Copyright 2022 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/IR/OpBase.td"
#include "tensorflow/core/transforms/utils/pdll/utils.pdll"

// Succeeds only when `attr` is a TypeAttr whose type is f32 or bf16. Any other
// attribute kind, or a TypeAttr holding a different type, fails the match.
Constraint AttrIsF32OrBF16(attr: Attr) [{
  TypeAttr type_attr = attr.dyn_cast<TypeAttr>();
  if (!type_attr) return failure();
  return success(type_attr.getValue().isa<Float32Type>() ||
                 type_attr.getValue().isa<BFloat16Type>());
}];

// Creates a generic `tfg._MklSwish` op to stand in for `op`.
//   op:       the matched op being replaced; its location, result types and
//             full attribute dictionary are copied onto the new op.
//   arg:      the single data operand of the new op.
//   controls: appended after `arg` as the remaining operands.
// Returns the newly created op.
Rewrite ReplaceMulWith_MklSwish(op: Op, arg: Value, controls: ValueRange) -> Op [{
  SmallVector<Value> operands;
  operands.push_back(arg);
  llvm::append_range(operands, controls);
  Operation *new_op = rewriter.create(op->getLoc(),
                                      rewriter.getStringAttr("tfg._MklSwish"),
                                      operands,
                                      op->getResultTypes(),
                                      op->getAttrs());
  return new_op;
}];

// Match op with form Mul(Sigmoid(x), x) and replace it with _MklSwish(x).
Pattern SigmoidAndMul0 {
  let sigmoid_arg = op<tfg.Sigmoid>(arg: Value);
  // `controls` captures every Mul operand after the two data inputs (the
  // control operands) so they can be forwarded to the replacement op.
  // `_` matches the "T" attribute without binding a name; the attribute only
  // has to satisfy AttrIsF32OrBF16.
  let root = op<tfg.Mul>(sigmoid_arg.0, arg, controls: ValueRange)
                 {"T" = _: AttrIsF32OrBF16};
  OpHasCpuDevice(root);
  replace root with ReplaceMulWith_MklSwish(root, arg, controls);
}

// Match op with form Mul(x, Sigmoid(x)) and replace it with _MklSwish(x).
Pattern SigmoidAndMul1 {
  let sigmoid_arg = op<tfg.Sigmoid>(arg: Value);
  // `controls` captures every Mul operand after the two data inputs (the
  // control operands) so they can be forwarded to the replacement op.
  // `_` matches the "T" attribute without binding a name, consistent with
  // SigmoidAndMul0; the attribute only has to satisfy AttrIsF32OrBF16.
  let root = op<tfg.Mul>(arg, sigmoid_arg.0, controls: ValueRange)
                 {"T" = _: AttrIsF32OrBF16};
  OpHasCpuDevice(root);
  replace root with ReplaceMulWith_MklSwish(root, arg, controls);
}