xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/utils/op_cat_helper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/utils/op_cat_helper.h"
17 
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/ir/dialect.h"
21 
22 namespace mlir {
23 namespace tfg {
24 
25 namespace {
26 
SplatElementsAttrHasValue(SplatElementsAttr attr,float v)27 bool SplatElementsAttrHasValue(SplatElementsAttr attr, float v) {
28   Type type = attr.getElementType();
29 
30 #define IF_SPLAT_VALUE_IS(DTYPE, VALUE)                                \
31   if (attr.getSplatValue<tensorflow::EnumToDataType<DTYPE>::Type>() == \
32       tensorflow::EnumToDataType<DTYPE>::Type(VALUE))                  \
33     return true;
34 
35   if (type.isInteger(1)) {
36     IF_SPLAT_VALUE_IS(tensorflow::DT_BOOL, v);
37   } else if (type.isSignedInteger()) {
38     if (type.isInteger(8)) {
39       IF_SPLAT_VALUE_IS(tensorflow::DT_INT8, v);
40     } else if (type.isInteger(16)) {
41       IF_SPLAT_VALUE_IS(tensorflow::DT_INT16, v);
42     } else if (type.isInteger(32)) {
43       IF_SPLAT_VALUE_IS(tensorflow::DT_INT32, v);
44     } else if (type.isInteger(64)) {
45       IF_SPLAT_VALUE_IS(tensorflow::DT_INT64, v);
46     }
47   } else if (type.isUnsignedInteger()) {
48     if (type.isInteger(8)) IF_SPLAT_VALUE_IS(tensorflow::DT_UINT8, v);
49     if (type.isInteger(16)) IF_SPLAT_VALUE_IS(tensorflow::DT_UINT16, v);
50   } else if (type.isF16()) {
51     IF_SPLAT_VALUE_IS(tensorflow::DT_HALF, v);
52   } else if (type.isF32()) {
53     IF_SPLAT_VALUE_IS(tensorflow::DT_FLOAT, v);
54   } else if (type.isF64()) {
55     IF_SPLAT_VALUE_IS(tensorflow::DT_DOUBLE, v);
56   } else if (type.isBF16()) {
57     IF_SPLAT_VALUE_IS(tensorflow::DT_BFLOAT16, v);
58   } else if (type.isa<ComplexType>()) {
59     ComplexType complex_type = type.cast<ComplexType>();
60     if (complex_type.getElementType().isF32()) {
61       IF_SPLAT_VALUE_IS(tensorflow::DT_COMPLEX64, v);
62     } else if (complex_type.getElementType().isF64()) {
63       IF_SPLAT_VALUE_IS(tensorflow::DT_COMPLEX128, v);
64     }
65   } else if (type.isa<tf_type::Qint8Type>()) {
66     IF_SPLAT_VALUE_IS(tensorflow::DT_QINT8, v);
67   } else if (type.isa<tf_type::Qint16Type>()) {
68     IF_SPLAT_VALUE_IS(tensorflow::DT_QINT16, v);
69   } else if (type.isa<tf_type::Qint32Type>()) {
70     IF_SPLAT_VALUE_IS(tensorflow::DT_QINT32, v);
71   } else if (type.isa<tf_type::Quint8Type>()) {
72     IF_SPLAT_VALUE_IS(tensorflow::DT_QUINT8, v);
73   } else if (type.isa<tf_type::Quint16Type>()) {
74     IF_SPLAT_VALUE_IS(tensorflow::DT_QUINT16, v);
75   }
76 #undef IF_SPLAT_VALUE_IS
77   return false;
78 }
79 
80 }  // namespace
81 
IsAggregate(TFOp op)82 bool OpCatHelper::IsAggregate(TFOp op) {
83   if (dialect_->IsAdd(op)) {
84     auto attr = op->getAttrOfType<TypeAttr>("T");
85     return !attr || !attr.getValue().isa<StringType>();
86   }
87   const tensorflow::OpDef *op_def = nullptr;
88   tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef(
89       op->getName().stripDialect().data(), &op_def);
90   return status.ok() && op_def->is_aggregate();
91 }
92 
IsCommutative(TFOp op)93 bool OpCatHelper::IsCommutative(TFOp op) {
94   if (dialect_->IsAdd(op)) {
95     auto attr = op->getAttrOfType<TypeAttr>("T");
96     return !attr || !attr.getValue().isa<StringType>();
97   }
98   const tensorflow::OpDef *op_def = nullptr;
99   tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef(
100       op->getName().stripDialect().data(), &op_def);
101   return status.ok() && op_def->is_commutative();
102 }
103 
IsOnes(TFOp op)104 bool OpCatHelper::IsOnes(TFOp op) {
105   if (dialect_->IsOnesLike(op)) return true;
106   if (dialect_->IsZerosLike(op)) return false;
107 
108   if (dialect_->IsFill(op)) {
109     TFOp value_op = op->getOperand(1).getDefiningOp();
110     return !value_op && IsOnes(value_op);
111   }
112 
113   if (!dialect_->IsConstant(op)) return false;
114 
115   SplatElementsAttr const_attr = op->getAttrOfType<SplatElementsAttr>("value");
116   if (!const_attr) return false;
117 
118   return SplatElementsAttrHasValue(const_attr, 1);
119 }
120 
IsZeros(TFOp op)121 bool OpCatHelper::IsZeros(TFOp op) {
122   if (dialect_->IsOnesLike(op)) return false;
123   if (dialect_->IsZerosLike(op)) return true;
124 
125   if (dialect_->IsFill(op)) {
126     TFOp value_op = op->getOperand(1).getDefiningOp();
127     return !value_op && IsZeros(value_op);
128   }
129 
130   if (!dialect_->IsConstant(op)) return false;
131 
132   SplatElementsAttr const_attr = op->getAttrOfType<SplatElementsAttr>("value");
133   if (!const_attr) return false;
134 
135   return SplatElementsAttrHasValue(const_attr, 0);
136 }
137 
IsPersistent(TFOp op)138 bool OpCatHelper::IsPersistent(TFOp op) {
139   return dialect_->IsConstant(op) || dialect_->IsVariable(op) ||
140          dialect_->IsHostConstant(op);
141 }
142 
IsDataset(TFOp op)143 bool OpCatHelper::IsDataset(TFOp op) {
144   static StringRef iterator_get_next = "IteratorGetNext";
145   static StringRef iterator_get_next_sync = "IteratorGetNextSync";
146   static StringRef dataset_to_single_element = "DatasetToSingleElement";
147   static StringRef reduce_data_set = "ReduceDataset";
148   StringRef op_name = op->getName().stripDialect();
149   // See `GetNodeClassForOp` in core/graph/graph.cc.
150   return op_name == iterator_get_next || op_name == iterator_get_next_sync ||
151          op_name == dataset_to_single_element || op_name == reduce_data_set;
152 }
153 
154 }  // namespace tfg
155 }  // namespace mlir
156