1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_ 18 19 #include "mlir/IR/Block.h" // from @llvm-project 20 #include "mlir/IR/Operation.h" // from @llvm-project 21 #include "mlir/IR/OperationSupport.h" // from @llvm-project 22 #include "mlir/Support/LLVM.h" // from @llvm-project 23 #include "mlir/Support/LogicalResult.h" // from @llvm-project 24 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" 25 26 namespace mlir { 27 namespace TFR { 28 29 // This is a hardcoded rule for mapping a TF op name to the corresponding 30 // TFR function name. Examples: 31 // tf.Pack => tf__pack 32 // tf.ConcatV2 => tf__concat_v2 33 // TODO(fengliuai): move to an util file. 34 std::string GetComposeFuncName(StringRef tf_op_name); 35 36 // This is a hardcoded rule for mapping a TFR function op name to the 37 // corresponding TF opname. Examples: 38 // tf__pack -> tf.Pack 39 // tf__concat_v2 => tf.ConcatV2 40 std::string GetTFOpName(StringRef compose_func_name); 41 42 // Validate the attributes of 'src' is either contained in the registered 43 // attribute sets or in the allowed list. 44 LogicalResult ValidateAttrs(Operation* src, const StringSet<>& registered); 45 46 // Copies all the allowed attributes in 'src' to 'dst'. The copy failed if the 47 // 'dst' has the attribute. Return a failure if there are any attributes are not 48 // allowed and also unregistered. 49 LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst, 50 const StringSet<>& registered); 51 52 // Copies all the allowed attributes in 'src' to 'dst'. FlatSymbolRefAttr is 53 // excluded. 54 LogicalResult CopyNonSymbolRefAttrs(CallOp src, Operation* dst); 55 56 // Propagates all the attributes in 'src' to the operations between 'begin' and 57 // 'end'. Operation 'end' is excluded. 58 void PropagateAttrsToOperations(CallOp src, Block::iterator begin, 59 Block::iterator end); 60 61 } // namespace TFR 62 } // namespace mlir 63 64 #endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_UTILS_UTILS_H_ 65