xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/remapper/pdll/mkl_patterns.pdll (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1// Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "mlir/IR/OpBase.td"
16#include "tensorflow/core/transforms/utils/pdll/utils.pdll"
17
// Native constraint on an attribute: succeeds only when `attr` is a TypeAttr
// whose wrapped type is Float32Type or BFloat16Type. Any other attribute kind
// (or any other wrapped type) makes the constraint fail.
Constraint AttrIsF32OrBF16(attr: Attr) [{
  if (TypeAttr wrapped = attr.dyn_cast<TypeAttr>()) {
    Type element_type = wrapped.getValue();
    bool is_supported = element_type.isa<Float32Type>() ||
                        element_type.isa<BFloat16Type>();
    return success(is_supported);
  }
  // Not a TypeAttr at all.
  return failure();
}];
24
// Native rewrite: builds a `tfg._MklSwish` operation that stands in for `op`.
//   op       - the operation being replaced; its location, result types and
//              attributes are copied onto the new operation unchanged.
//   arg      - the value placed first in the new operation's operand list.
//   controls - the values appended after `arg` in the operand list.
// Returns the newly created operation.
Rewrite ReplaceMulWith_MklSwish(op: Op, arg: Value, controls: ValueRange) -> Op [{
  // Operand order matters: `arg` first, then every value from `controls`.
  SmallVector<Value> swish_operands = {arg};
  swish_operands.append(controls.begin(), controls.end());
  return rewriter.create(op->getLoc(),
                         rewriter.getStringAttr("tfg._MklSwish"),
                         swish_operands,
                         op->getResultTypes(),
                         op->getAttrs());
}];
36
37// Match op with form Mul(Sigmoid(x), x)
Pattern SigmoidAndMul0 {
  // `input` is the single operand of the Sigmoid; the same value must also be
  // the second operand of the Mul for the pattern to match.
  let sigmoid = op<tfg.Sigmoid>(input: Value);
  // `ctls` captures the Mul operands that follow the two data operands.
  // The "T" attribute must satisfy AttrIsF32OrBF16; `_` discards its binding.
  let mul = op<tfg.Mul>(sigmoid.0, input, ctls: ValueRange) { "T" = _: AttrIsF32OrBF16};
  // Additional constraint supplied by the included utils.pdll.
  OpHasCpuDevice(mul);
  replace mul with ReplaceMulWith_MklSwish(mul, input, ctls);
}
45
46// Match op with form Mul(x, Sigmoid(x))
Pattern SigmoidAndMul1 {
  // `arg` is the single operand of the Sigmoid; the same value must also be
  // the first operand of the Mul for the pattern to match.
  let sigmoid_arg = op<tfg.Sigmoid>(arg: Value);
  // `controls` captures the Mul operands that follow the two data operands.
  // The "T" attribute must satisfy AttrIsF32OrBF16; its value is not needed
  // afterwards, so it is bound to `_` rather than to a named variable.
  let root = op<tfg.Mul>(arg, sigmoid_arg.0, controls: ValueRange) { "T" = _: AttrIsF32OrBF16};
  // Additional constraint supplied by the included utils.pdll.
  OpHasCpuDevice(root);
  replace root with ReplaceMulWith_MklSwish(root, arg, controls);
}
54