xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the op traits used in the MLIR TensorFlow dialect.
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
19 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
20 
21 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
22 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
24 #include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
25 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
28 
29 namespace mlir {
30 namespace OpTrait {
31 namespace TF {
32 
33 // Verifies if 'ref_type' is a REF type corresponding to 'type'.
VerifyRefTypeMatch(mlir::Type type,mlir::Type maybe_ref_type)34 static inline LogicalResult VerifyRefTypeMatch(mlir::Type type,
35                                                mlir::Type maybe_ref_type) {
36   if (auto ref_type =
37           maybe_ref_type.dyn_cast<mlir::tf_type::TensorFlowRefType>())
38     return success(ref_type.RemoveRef().getTypeID() == type.getTypeID());
39   return failure();
40 }
41 
42 // This class provides verification for ops that are known to have the same
43 // result types and all operands are either of the same type as result or a REF
44 // type corresponding to the result type.
45 // TODO(jpienaar): Update the name and the description.
46 template <typename ConcreteType>
47 class OperandsSameAsResultsTypeOrRef
48     : public TraitBase<ConcreteType, OperandsSameAsResultsTypeOrRef> {
49  public:
verifyTrait(Operation * op)50   static LogicalResult verifyTrait(Operation* op) {
51     LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
52     if (failed(shapeMatch)) return shapeMatch;
53     Type type = op->getResult(0).getType();
54     // Verify that the first result type is same as the rest of the results.
55     // We skip the comparison against itself.
56     for (auto result_type : llvm::drop_begin(op->getResultTypes(), 1)) {
57       if (!mlir::tf_type::HasCompatibleElementTypes(type, result_type))
58         return op->emitOpError()
59                << "requires all return types to have compatible element types";
60     }
61     for (auto operand_type : op->getOperandTypes()) {
62       if (!mlir::tf_type::HasCompatibleElementTypes(
63               operand_type, type, /*may_ignore_ref_type_lhs=*/true))
64         return op->emitError() << "requires all operands and results to have "
65                                   "compatible element types";
66     }
67     return success();
68   }
69 };
70 
71 namespace detail {
verifySameOperandsAndResultElementTypeResolveRef(Operation * op)72 inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef(
73     Operation* op) {
74   Type element_type;
75   if (op->getNumResults() > 0) {
76     element_type = mlir::tf_type::GetElementTypeOrSelfResolveRef(
77         op->getResult(0).getType());
78   } else if (op->getNumOperands() > 0) {
79     element_type = mlir::tf_type::GetElementTypeOrSelfResolveRef(
80         op->getOperand(0).getType());
81   } else {
82     // Nothing to check.
83     return success();
84   }
85   // Verify that all result element types are compatible to `element_type`.
86   for (const auto& result_type : op->getResultTypes()) {
87     if (mlir::tf_type::GetElementTypeOrSelfResolveRef(result_type) !=
88         element_type) {
89       return op->emitOpError(
90           "requires compatible element types for all operands and results");
91     }
92   }
93   // Verify that all operand element types are compatible to `element_type`.
94   for (const auto& operand_type : op->getOperandTypes()) {
95     if (mlir::tf_type::GetElementTypeOrSelfResolveRef(operand_type) !=
96         element_type) {
97       return op->emitOpError(
98           "requires compatible element types for all operands and results");
99     }
100   }
101   return success();
102 }
103 }  // namespace detail
104 
105 // Verifies that op has the same operand and result element types (or type
106 // itself, if scalar) after resolving reference types (i.e., after converting
107 // reference types to their corresponding TensorFlow or standard types).
108 template <typename ConcreteType>
109 class SameOperandsAndResultElementTypeResolveRef
110     : public TraitBase<ConcreteType,
111                        SameOperandsAndResultElementTypeResolveRef> {
112  public:
verifyTrait(Operation * op)113   static LogicalResult verifyTrait(Operation* op) {
114     return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
115   }
116 };
117 
118 // Verifies that op has the same operand and result types after resolving
119 // reference types (i.e., after converting reference types to their
120 // corresponding TensorFlow or standard types).
121 template <typename ConcreteType>
122 class SameOperandsAndResultTypeResolveRef
123     : public TraitBase<ConcreteType, SameOperandsAndResultTypeResolveRef> {
124  public:
verifyTrait(Operation * op)125   static LogicalResult verifyTrait(Operation* op) {
126     if (failed(impl::verifySameOperandsAndResultShape(op))) return failure();
127     return detail::verifySameOperandsAndResultElementTypeResolveRef(op);
128   }
129 };
130 
131 // Layout agnostic operations do not depend on the operands data layout (data
132 // format), as and example all element wise operations are layout agnostic.
133 template <typename ConcreteType>
134 class LayoutAgnostic : public TraitBase<ConcreteType, LayoutAgnostic> {};
135 
136 // Trait to indicate operations that cannot be duplicated as they might carry
137 // certain state around within their implementations.
138 template <typename ConcreteType>
139 class CannotDuplicate : public TraitBase<ConcreteType, CannotDuplicate> {
140  public:
verifyTrait(Operation * op)141   static LogicalResult verifyTrait(Operation* op) {
142     if (MemoryEffectOpInterface::hasNoEffect(op))
143       return op->emitError(
144           "operations with no side effects cannot have CannotDuplicate trait");
145     return success();
146   }
147 };
148 
149 // Trait to indicate an operation cannot be constant folded.
150 template <typename ConcreteType>
151 class NoConstantFold : public TraitBase<ConcreteType, NoConstantFold> {};
152 
153 // Coefficient-wise binary operation with implicit broadcasting support, for
154 // example tf.Sub operation.
155 template <typename ConcreteType>
156 class CwiseBinary : public TraitBase<ConcreteType, CwiseBinary> {};
157 
158 // Coefficient-wise unary operation, for example tf.Sqrt operation.
159 template <typename ConcreteType>
160 class CwiseUnary : public TraitBase<ConcreteType, CwiseUnary> {};
161 
162 // Indicates that any returned resource is unique.
163 template <typename ConcreteType>
164 class UniqueResourceAllocation
165     : public TraitBase<ConcreteType, UniqueResourceAllocation> {
166  public:
167   // Implements method required for `ResourceHandleAllocatorInterface`.
168   llvm::SmallVector<mlir::TF::ResourceHandleValueAndId>
GetResourceHandleValueAndIdList(llvm::SmallDenseMap<mlir::TF::ResourceHandle,int64_t> & resource_handle_id_map,int64_t & next_id)169   GetResourceHandleValueAndIdList(
170       llvm::SmallDenseMap<mlir::TF::ResourceHandle, int64_t>&
171           resource_handle_id_map,
172       int64_t& next_id) {
173     llvm::SmallVector<mlir::TF::ResourceHandleValueAndId> resource_vec;
174     for (Value resource :
175          mlir::tf_type::filter_resources(this->getOperation()->getResults())) {
176       resource_vec.push_back({resource, next_id++});
177     }
178     return resource_vec;
179   }
180 };
181 
182 }  // namespace TF
183 }  // namespace OpTrait
184 }  // namespace mlir
185 
186 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_
187