1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file defines the op traits used in the MLIR TensorFlow dialect.
17
18 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
19 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
20
21 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
22 #include "mlir/IR/OpDefinition.h" // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
24 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
25 #include "mlir/Support/LogicalResult.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
28
29 namespace mlir {
30 namespace OpTrait {
31 namespace TF {
32
33 // Verifies if 'ref_type' is a REF type corresponding to 'type'.
VerifyRefTypeMatch(mlir::Type type,mlir::Type maybe_ref_type)34 static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
35 mlir::Type maybe_ref_type) {
36 if (auto ref_type =
37 maybe_ref_type.dyn_cast<mlir::tf_type::TensorFlowRefType>())
38 return success(ref_type.RemoveRef().getTypeID() == type.getTypeID());
39 return failure();
40 }
41
42 // This class provides verification for ops that are known to have the same
43 // result types and all operands are either of the same type as result or a REF
44 // type corresponding to the result type.
45 // TODO(jpienaar): Update the name and the description.
46 template <typename ConcreteType>
47 class OperandsSameAsResultsTypeOrRef
48 : public TraitBase<ConcreteType, OperandsSameAsResultsTypeOrRef> {
49 public:
verifyTrait(Operation * op)50 static LogicalResult verifyTrait(Operation* op) {
51 LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
52 if (failed(shapeMatch)) return shapeMatch;
53 Type type = op->getResult(0).getType();
54 // Verify that the first result type is same as the rest of the results.
55 // We skip the comparison against itself.
56 for (auto result_type : llvm::drop_begin(op->getResultTypes(), 1)) {
57 if (!mlir::tf_type::HasCompatibleElementTypes(type, result_type))
58 return op->emitOpError()
59 << "requires all return types to have compatible element types";
60 }
61 for (auto operand_type : op->getOperandTypes()) {
62 if (!mlir::tf_type::HasCompatibleElementTypes(
63 operand_type, type, /*may_ignore_ref_type_lhs=*/true))
64 return op->emitError() << "requires all operands and results to have "
65 "compatible element types";
66 }
67 return success();
68 }
69 };
70
71 namespace detail {
verifySameOperandsAndResultElementTypeResolveRef(Operation * op)72 inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
73 Operation* op) {
74 Type element_type;
75 if (op->getNumResults() > 0) {
76 element_type = mlir::tf_type::GetElementTypeOrSelfResolveRef(
77 op->getResult(0).getType());
78 } else if (op->getNumOperands() > 0) {
79 element_type = mlir::tf_type::GetElementTypeOrSelfResolveRef(
80 op->getOperand(0).getType());
81 } else {
82 // Nothing to check.
83 return success();
84 }
85 // Verify that all result element types are compatible to `element_type`.
86 for (const auto& result_type : op->getResultTypes()) {
87 if (mlir::tf_type::GetElementTypeOrSelfResolveRef(result_type) !=
88 element_type) {
89 return op->emitOpError(
90 "requires compatible element types for all operands and results");
91 }
92 }
93 // Verify that all operand element types are compatible to `element_type`.
94 for (const auto& operand_type : op->getOperandTypes()) {
95 if (mlir::tf_type::GetElementTypeOrSelfResolveRef(operand_type) !=
96 element_type) {
97 return op->emitOpError(
98 "requires compatible element types for all operands and results");
99 }
100 }
101 return success();
102 }
103 } // namespace detail
104
105 // Verifies that op has the same operand and result element types (or type
106 // itself, if scalar) after resolving reference types (i.e., after converting
107 // reference types to their corresponding TensorFlow or standard types).
108 template <typename ConcreteType>
109 class SameOperandsAndResultElementTypeResolveRef
110 : public TraitBase<ConcreteType,
111 SameOperandsAndResultElementTypeResolveRef> {
112 public:
verifyTrait(Operation * op)113 static LogicalResult verifyTrait(Operation* op) {
114 return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
115 }
116 };
117
118 // Verifies that op has the same operand and result types after resolving
119 // reference types (i.e., after converting reference types to their
120 // corresponding TensorFlow or standard types).
121 template <typename ConcreteType>
122 class SameOperandsAndResultTypeResolveRef
123 : public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef> {
124 public:
verifyTrait(Operation * op)125 static LogicalResult verifyTrait(Operation* op) {
126 if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
127 return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
128 }
129 };
130
// Layout agnostic operations do not depend on the operands' data layout (data
// format); as an example, all element-wise operations are layout agnostic.
// This is a pure marker trait: the class body is empty and defines no
// verifier of its own.
template <typename ConcreteType>
class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic> {};
135
136 // Trait to indicate operations that cannot be duplicated as they might carry
137 // certain state around within their implementations.
138 template <typename ConcreteType>
139 class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate> {
140 public:
verifyTrait(Operation * op)141 static LogicalResult verifyTrait(Operation* op) {
142 if (MemoryEffectOpInterface::hasNoEffect(op))
143 return op->emitError(
144 "operations with no side effects cannot have CannotDuplicate trait");
145 return success();
146 }
147 };
148
// Trait to indicate an operation cannot be constant folded.
// This is a pure marker trait: the class body is empty and defines no
// verifier of its own.
template <typename ConcreteType>
class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold> {};
152
// Coefficient-wise binary operation with implicit broadcasting support, for
// example the tf.Sub operation.
// This is a pure marker trait: the class body is empty and defines no
// verifier of its own.
template <typename ConcreteType>
class CwiseBinary : public TraitBase<ConcreteType, CwiseBinary> {};
157
// Coefficient-wise unary operation, for example the tf.Sqrt operation.
// This is a pure marker trait: the class body is empty and defines no
// verifier of its own.
template <typename ConcreteType>
class CwiseUnary : public TraitBase<ConcreteType, CwiseUnary> {};
161
162 // Indicates that any returned resource is unique.
163 template <typename ConcreteType>
164 class UniqueResourceAllocation
165 : public TraitBase<ConcreteType, UniqueResourceAllocation> {
166 public:
167 // Implements method required for `ResourceHandleAllocatorInterface`.
168 llvm::SmallVector<mlir::TF::ResourceHandleValueAndId>
GetResourceHandleValueAndIdList(llvm::SmallDenseMap<mlir::TF::ResourceHandle,int64_t> & resource_handle_id_map,int64_t & next_id)169 GetResourceHandleValueAndIdList(
170 llvm::SmallDenseMap<mlir::TF::ResourceHandle, int64_t>&
171 resource_handle_id_map,
172 int64_t& next_id) {
173 llvm::SmallVector<mlir::TF::ResourceHandleValueAndId> resource_vec;
174 for (Value resource :
175 mlir::tf_type::filter_resources(this->getOperation()->getResults())) {
176 resource_vec.push_back({resource, next_id++});
177 }
178 return resource_vec;
179 }
180 };
181
182 } // namespace TF
183 } // namespace OpTrait
184 } // namespace mlir
185
186 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
187