xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfr/utils/utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
18 
19 #include "mlir/IR/Block.h"  // from @llvm-project
20 #include "mlir/IR/Operation.h"  // from @llvm-project
21 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
22 #include "mlir/Support/LLVM.h"  // from @llvm-project
23 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
25 
26 namespace mlir {
27 namespace TFR {
28 
29 // Maps a TF op name to the name of the TFR compose function implementing it,
30 // using a hardcoded naming rule. Examples:
31 //   tf.Pack => tf__pack
32 //   tf.ConcatV2 => tf__concat_v2
33 // NOTE(review): old TODO(fengliuai) "move to an util file" looks done; confirm.
34 std::string GetComposeFuncName(StringRef tf_op_name);
35 
36 // Maps a TFR compose function name back to the corresponding TF op name (the
37 // reverse direction of GetComposeFuncName), by a hardcoded rule. Examples:
38 //   tf__pack => tf.Pack
39 //   tf__concat_v2 => tf.ConcatV2
40 std::string GetTFOpName(StringRef compose_func_name);
41 
42 // Checks that each attribute of 'src' is either in 'registered' or in an
43 // allowed list kept by the implementation (not visible here); fails otherwise.
44 LogicalResult ValidateAttrs(Operation* src, const StringSet<>& registered);
45 
46 // Copies the allowed attributes of 'src' onto 'dst'; copying an attribute
47 // that 'dst' already has fails. Returns failure if 'src' has any attribute
48 // that is neither allowed nor in 'registered'.
49 LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst,
50                                            const StringSet<>& registered);
51 
52 // Copies the allowed attributes of 'src' onto 'dst', excluding any whose
53 // value is a FlatSymbolRefAttr (presumably the callee symbol; confirm).
54 LogicalResult CopyNonSymbolRefAttrs(CallOp src, Operation* dst);
55 
56 // Propagates all the attributes of 'src' to every operation in the half-open
57 // range ['begin', 'end') of a block; operation 'end' itself is excluded.
58 void PropagateAttrsToOperations(CallOp src, Block::iterator begin,
59                                 Block::iterator end);
60 
61 }  // namespace TFR
62 }  // namespace mlir
63 
64 #endif  // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_
65