1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_CANONICALIZATION_HELPER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_CANONICALIZATION_HELPER_H_
18 
19 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
20 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
21 
22 namespace mlir {
23 namespace TF {
24 
25 // Eliminate attributes that are not needed, but can get attached to Ops
26 // during import.
27 template <typename Op>
28 struct DropAttributes : public OpRewritePattern<Op> {
29   using OpRewritePattern<Op>::OpRewritePattern;
30 
31   // Drop the "output_shapes" attribute.
matchAndRewriteDropAttributes32   LogicalResult matchAndRewrite(Op op,
33                                 PatternRewriter &rewriter) const override {
34     bool found = !!op->removeAttr("output_shapes");
35     return success(found);
36   }
37 };
38 
39 // Helper function to create TF op while copying all underscore attributes from
40 // another TF op.
41 // TODO(jpienaar): This is a workaround until behavior is established.
42 template <typename OpTy, typename... Args>
CreateTfOp(RewriterBase & b,Operation * op,Args &&...args)43 OpTy CreateTfOp(RewriterBase &b, Operation *op, Args &&...args) {
44   auto ret = b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
45   CopyDeviceAndUnderscoredAttributes(op, ret.getOperation());
46   return ret;
47 }
48 
49 // Helper function to replace TF op with another op while copying all underscore
50 // attributes from the TF op.
51 // TODO(jpienaar): This is a workaround until behavior is established.
52 template <typename OpTy, typename... Args>
ReplaceTfOpWithNewOp(RewriterBase & b,Operation * op,Args &&...args)53 OpTy ReplaceTfOpWithNewOp(RewriterBase &b, Operation *op, Args &&...args) {
54   auto ret = CreateTfOp<OpTy>(b, op, std::forward<Args>(args)...);
55   b.replaceOp(op, ret.getOperation()->getResults());
56   return ret;
57 }
58 
59 }  // namespace TF
60 }  // namespace mlir
61 
62 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_CANONICALIZATION_HELPER_H_
63