1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/transforms/utils/op_cat_helper.h"
17
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/ir/dialect.h"
21
22 namespace mlir {
23 namespace tfg {
24
25 namespace {
26
SplatElementsAttrHasValue(SplatElementsAttr attr,float v)27 bool SplatElementsAttrHasValue(SplatElementsAttr attr, float v) {
28 Type type = attr.getElementType();
29
30 #define IF_SPLAT_VALUE_IS(DTYPE, VALUE) \
31 if (attr.getSplatValue<tensorflow::EnumToDataType<DTYPE>::Type>() == \
32 tensorflow::EnumToDataType<DTYPE>::Type(VALUE)) \
33 return true;
34
35 if (type.isInteger(1)) {
36 IF_SPLAT_VALUE_IS(tensorflow::DT_BOOL, v);
37 } else if (type.isSignedInteger()) {
38 if (type.isInteger(8)) {
39 IF_SPLAT_VALUE_IS(tensorflow::DT_INT8, v);
40 } else if (type.isInteger(16)) {
41 IF_SPLAT_VALUE_IS(tensorflow::DT_INT16, v);
42 } else if (type.isInteger(32)) {
43 IF_SPLAT_VALUE_IS(tensorflow::DT_INT32, v);
44 } else if (type.isInteger(64)) {
45 IF_SPLAT_VALUE_IS(tensorflow::DT_INT64, v);
46 }
47 } else if (type.isUnsignedInteger()) {
48 if (type.isInteger(8)) IF_SPLAT_VALUE_IS(tensorflow::DT_UINT8, v);
49 if (type.isInteger(16)) IF_SPLAT_VALUE_IS(tensorflow::DT_UINT16, v);
50 } else if (type.isF16()) {
51 IF_SPLAT_VALUE_IS(tensorflow::DT_HALF, v);
52 } else if (type.isF32()) {
53 IF_SPLAT_VALUE_IS(tensorflow::DT_FLOAT, v);
54 } else if (type.isF64()) {
55 IF_SPLAT_VALUE_IS(tensorflow::DT_DOUBLE, v);
56 } else if (type.isBF16()) {
57 IF_SPLAT_VALUE_IS(tensorflow::DT_BFLOAT16, v);
58 } else if (type.isa<ComplexType>()) {
59 ComplexType complex_type = type.cast<ComplexType>();
60 if (complex_type.getElementType().isF32()) {
61 IF_SPLAT_VALUE_IS(tensorflow::DT_COMPLEX64, v);
62 } else if (complex_type.getElementType().isF64()) {
63 IF_SPLAT_VALUE_IS(tensorflow::DT_COMPLEX128, v);
64 }
65 } else if (type.isa<tf_type::Qint8Type>()) {
66 IF_SPLAT_VALUE_IS(tensorflow::DT_QINT8, v);
67 } else if (type.isa<tf_type::Qint16Type>()) {
68 IF_SPLAT_VALUE_IS(tensorflow::DT_QINT16, v);
69 } else if (type.isa<tf_type::Qint32Type>()) {
70 IF_SPLAT_VALUE_IS(tensorflow::DT_QINT32, v);
71 } else if (type.isa<tf_type::Quint8Type>()) {
72 IF_SPLAT_VALUE_IS(tensorflow::DT_QUINT8, v);
73 } else if (type.isa<tf_type::Quint16Type>()) {
74 IF_SPLAT_VALUE_IS(tensorflow::DT_QUINT16, v);
75 }
76 #undef IF_SPLAT_VALUE_IS
77 return false;
78 }
79
80 } // namespace
81
IsAggregate(TFOp op)82 bool OpCatHelper::IsAggregate(TFOp op) {
83 if (dialect_->IsAdd(op)) {
84 auto attr = op->getAttrOfType<TypeAttr>("T");
85 return !attr || !attr.getValue().isa<StringType>();
86 }
87 const tensorflow::OpDef *op_def = nullptr;
88 tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef(
89 op->getName().stripDialect().data(), &op_def);
90 return status.ok() && op_def->is_aggregate();
91 }
92
IsCommutative(TFOp op)93 bool OpCatHelper::IsCommutative(TFOp op) {
94 if (dialect_->IsAdd(op)) {
95 auto attr = op->getAttrOfType<TypeAttr>("T");
96 return !attr || !attr.getValue().isa<StringType>();
97 }
98 const tensorflow::OpDef *op_def = nullptr;
99 tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef(
100 op->getName().stripDialect().data(), &op_def);
101 return status.ok() && op_def->is_commutative();
102 }
103
IsOnes(TFOp op)104 bool OpCatHelper::IsOnes(TFOp op) {
105 if (dialect_->IsOnesLike(op)) return true;
106 if (dialect_->IsZerosLike(op)) return false;
107
108 if (dialect_->IsFill(op)) {
109 TFOp value_op = op->getOperand(1).getDefiningOp();
110 return !value_op && IsOnes(value_op);
111 }
112
113 if (!dialect_->IsConstant(op)) return false;
114
115 SplatElementsAttr const_attr = op->getAttrOfType<SplatElementsAttr>("value");
116 if (!const_attr) return false;
117
118 return SplatElementsAttrHasValue(const_attr, 1);
119 }
120
IsZeros(TFOp op)121 bool OpCatHelper::IsZeros(TFOp op) {
122 if (dialect_->IsOnesLike(op)) return false;
123 if (dialect_->IsZerosLike(op)) return true;
124
125 if (dialect_->IsFill(op)) {
126 TFOp value_op = op->getOperand(1).getDefiningOp();
127 return !value_op && IsZeros(value_op);
128 }
129
130 if (!dialect_->IsConstant(op)) return false;
131
132 SplatElementsAttr const_attr = op->getAttrOfType<SplatElementsAttr>("value");
133 if (!const_attr) return false;
134
135 return SplatElementsAttrHasValue(const_attr, 0);
136 }
137
IsPersistent(TFOp op)138 bool OpCatHelper::IsPersistent(TFOp op) {
139 return dialect_->IsConstant(op) || dialect_->IsVariable(op) ||
140 dialect_->IsHostConstant(op);
141 }
142
IsDataset(TFOp op)143 bool OpCatHelper::IsDataset(TFOp op) {
144 static StringRef iterator_get_next = "IteratorGetNext";
145 static StringRef iterator_get_next_sync = "IteratorGetNextSync";
146 static StringRef dataset_to_single_element = "DatasetToSingleElement";
147 static StringRef reduce_data_set = "ReduceDataset";
148 StringRef op_name = op->getName().stripDialect();
149 // See `GetNodeClassForOp` in core/graph/graph.cc.
150 return op_name == iterator_get_next || op_name == iterator_get_next_sync ||
151 op_name == dataset_to_single_element || op_name == reduce_data_set;
152 }
153
154 } // namespace tfg
155 } // namespace mlir
156