1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "mlir/IR/Operation.h" // from @llvm-project 17 #include "mlir/IR/OperationSupport.h" // from @llvm-project 18 #include "tensorflow/core/ir/dialect.h" 19 #include "tensorflow/core/ir/tf_op_wrapper.h" 20 21 namespace mlir { 22 namespace tfg { 23 IsAdd(TFOp op) const24bool TFGraphDialect::IsAdd(TFOp op) const { 25 StringAttr op_name = op->getName().getIdentifier(); 26 27 if (op_name == add_v2_) return true; 28 if (op_name == add_) return !op->getAttrOfType<StringAttr>("T"); 29 return false; 30 } 31 IsAddN(TFOp op) const32bool TFGraphDialect::IsAddN(TFOp op) const { 33 StringAttr op_name = op->getName().getIdentifier(); 34 return op_name == add_n_; 35 } 36 IsAll(TFOp op) const37bool TFGraphDialect::IsAll(TFOp op) const { 38 StringAttr op_name = op->getName().getIdentifier(); 39 return op_name == all_; 40 } 41 IsAngle(TFOp op) const42bool TFGraphDialect::IsAngle(TFOp op) const { 43 StringAttr op_name = op->getName().getIdentifier(); 44 return op_name == angle_; 45 } 46 IsAny(TFOp op) const47bool TFGraphDialect::IsAny(TFOp op) const { 48 StringAttr op_name = op->getName().getIdentifier(); 49 return op_name == any_; 50 } 51 IsAnyDiv(TFOp op) const52bool TFGraphDialect::IsAnyDiv(TFOp op) const { 53 StringAttr op_name = op->getName().getIdentifier(); 54 return op_name == real_div_ || op_name == div_ || IsXdivy(op) || 55 op_name == floor_div_ || op_name == truncate_div_; 56 } 57 IsAnyBatchMatMul(TFOp op) const58bool TFGraphDialect::IsAnyBatchMatMul(TFOp op) const { 59 StringAttr op_name = op->getName().getIdentifier(); 60 return op_name == batch_matmul_ || op_name == batch_matmul_v2_; 61 } 62 IsAnyMatMul(TFOp op) const63bool TFGraphDialect::IsAnyMatMul(TFOp op) const { 64 StringAttr op_name = op->getName().getIdentifier(); 65 return op_name == matmul_ || op_name == sparse_matmul_ || 66 IsAnyBatchMatMul(op) || IsQuantizedMatMul(op); 67 } 68 IsAnyMax(TFOp op) const69bool TFGraphDialect::IsAnyMax(TFOp op) const { 70 StringAttr op_name = op->getName().getIdentifier(); 71 return op_name == max_ || op_name == segment_max_ || 72 op_name == unsorted_segment_max_; 73 } 74 IsAnyMaxPool(TFOp op) const75bool TFGraphDialect::IsAnyMaxPool(TFOp op) const { 76 StringAttr op_name = op->getName().getIdentifier(); 77 return op_name == max_pool_ || op_name == max_pool_v2_ || 78 op_name == max_pool_3d_ || op_name == max_pool_with_argmax_ || 79 op_name == fractional_max_pool_; 80 } 81 IsAnyMin(TFOp op) const82bool TFGraphDialect::IsAnyMin(TFOp op) const { 83 StringAttr op_name = op->getName().getIdentifier(); 84 return op_name == min_ || op_name == segment_min_ || 85 op_name == unsorted_segment_min_; 86 } 87 IsAnySparseSegmentReduction(TFOp op) const88bool TFGraphDialect::IsAnySparseSegmentReduction(TFOp op) const { 89 StringAttr op_name = op->getName().getIdentifier(); 90 return op_name == sparse_segment_sum_ || 91 op_name == sparse_segment_sum_with_num_segments_ || 92 op_name == sparse_segment_mean_ || 93 op_name == sparse_segment_mean_with_num_segments_ || 94 op_name == sparse_segment_sqrtn_ || 95 op_name == sparse_segment_sqrtn_with_num_segments_; 96 } 97 IsApproximateEqual(TFOp op) const98bool TFGraphDialect::IsApproximateEqual(TFOp op) const { 99 StringAttr op_name = op->getName().getIdentifier(); 100 return op_name == approximate_equal_; 101 } 102 IsArg(TFOp op) const103bool TFGraphDialect::IsArg(TFOp op) const { 104 StringAttr op_name = op->getName().getIdentifier(); 105 return op_name == arg_ || op_name == device_arg_; 106 } 107 IsArgMax(TFOp op) const108bool TFGraphDialect::IsArgMax(TFOp op) const { 109 StringAttr op_name = op->getName().getIdentifier(); 110 return op_name == arg_max_; 111 } 112 IsArgMin(TFOp op) const113bool TFGraphDialect::IsArgMin(TFOp op) const { 114 StringAttr op_name = op->getName().getIdentifier(); 115 return op_name == arg_min_; 116 } 117 IsAvgPoolGrad(TFOp op) const118bool TFGraphDialect::IsAvgPoolGrad(TFOp op) const { 119 StringAttr op_name = op->getName().getIdentifier(); 120 return op_name == arg_pool_grad_; 121 } 122 IsAssign(TFOp op) const123bool TFGraphDialect::IsAssign(TFOp op) const { 124 StringAttr op_name = op->getName().getIdentifier(); 125 return op_name == assign_ || op_name == assign_variable_op_; 126 } 127 IsAssert(TFOp op) const128bool TFGraphDialect::IsAssert(TFOp op) const { 129 StringAttr op_name = op->getName().getIdentifier(); 130 return op_name == assert_; 131 } 132 IsAsString(TFOp op) const133bool TFGraphDialect::IsAsString(TFOp op) const { 134 StringAttr op_name = op->getName().getIdentifier(); 135 return op_name == as_string_; 136 } 137 IsAtan2(TFOp op) const138bool TFGraphDialect::IsAtan2(TFOp op) const { 139 StringAttr op_name = op->getName().getIdentifier(); 140 return op_name == atan2_; 141 } 142 IsBetainc(TFOp op) const143bool TFGraphDialect::IsBetainc(TFOp op) const { 144 StringAttr op_name = op->getName().getIdentifier(); 145 return op_name == betainc_; 146 } 147 IsBiasAdd(TFOp op) const148bool TFGraphDialect::IsBiasAdd(TFOp op) const { 149 StringAttr op_name = op->getName().getIdentifier(); 150 return op_name == bias_add_ || op_name == bias_add_v1_; 151 } 152 IsBiasAddV2(TFOp op) const153bool TFGraphDialect::IsBiasAddV2(TFOp op) const { 154 StringAttr op_name = op->getName().getIdentifier(); 155 return op_name == bias_add_; 156 } 157 IsBiasAddGrad(TFOp op) const158bool TFGraphDialect::IsBiasAddGrad(TFOp op) const { 159 StringAttr op_name = op->getName().getIdentifier(); 160 return op_name == bias_add_grad_; 161 } 162 IsBitcast(TFOp op) const163bool TFGraphDialect::IsBitcast(TFOp op) const { 164 StringAttr op_name = op->getName().getIdentifier(); 165 return op_name == bitcast_; 166 } 167 IsBroadcastTo(TFOp op) const168bool TFGraphDialect::IsBroadcastTo(TFOp op) const { 169 StringAttr op_name = op->getName().getIdentifier(); 170 return op_name == broadcast_to_; 171 } 172 IsCast(TFOp op) const173bool TFGraphDialect::IsCast(TFOp op) const { 174 StringAttr op_name = op->getName().getIdentifier(); 175 return op_name == cast_; 176 } 177 IsCheckNumerics(TFOp op) const178bool TFGraphDialect::IsCheckNumerics(TFOp op) const { 179 StringAttr op_name = op->getName().getIdentifier(); 180 return op_name == check_numerics_; 181 } 182 IsCollective(TFOp op) const183bool TFGraphDialect::IsCollective(TFOp op) const { 184 StringAttr op_name = op->getName().getIdentifier(); 185 return op_name == collective_reduce_ || op_name == collective_bcast_send_ || 186 op_name == collective_bcast_recv_; 187 } 188 IsComplex(TFOp op) const189bool TFGraphDialect::IsComplex(TFOp op) const { 190 StringAttr op_name = op->getName().getIdentifier(); 191 return op_name == complex_; 192 } 193 IsComplexAbs(TFOp op) const194bool TFGraphDialect::IsComplexAbs(TFOp op) const { 195 StringAttr op_name = op->getName().getIdentifier(); 196 return op_name == complex_abs_; 197 } 198 IsConcat(TFOp op) const199bool TFGraphDialect::IsConcat(TFOp op) const { 200 StringAttr op_name = op->getName().getIdentifier(); 201 return op_name == concat_ || IsConcatV2(op); 202 } 203 IsConcatV2(TFOp op) const204bool TFGraphDialect::IsConcatV2(TFOp op) const { 205 StringAttr op_name = op->getName().getIdentifier(); 206 return op_name == concat_v2_; 207 } 208 IsConcatOffset(TFOp op) const209bool TFGraphDialect::IsConcatOffset(TFOp op) const { 210 StringAttr op_name = op->getName().getIdentifier(); 211 return op_name == concat_offset_; 212 } 213 IsConstant(TFOp op) const214bool TFGraphDialect::IsConstant(TFOp op) const { 215 StringAttr op_name = op->getName().getIdentifier(); 216 return op_name == const_; 217 } 218 IsConj(TFOp op) const219bool TFGraphDialect::IsConj(TFOp op) const { 220 StringAttr op_name = op->getName().getIdentifier(); 221 return op_name == conj_; 222 } 223 IsConjugateTranspose(TFOp op) const224bool TFGraphDialect::IsConjugateTranspose(TFOp op) const { 225 StringAttr op_name = op->getName().getIdentifier(); 226 return op_name == conjugate_transpose_; 227 } 228 229 // TODO(chiahungduan): Should we use certain helpers like IsEnter(). IsControlFlow(TFOp op) const230bool TFGraphDialect::IsControlFlow(TFOp op) const { 231 StringAttr op_name = op->getName().getIdentifier(); 232 233 return op_name == control_trigger_ || op_name == enter_ || op_name == exit_ || 234 op_name == loop_cond_ || op_name == merge_ || op_name == xla_merge_ || 235 op_name == next_iteration_ || op_name == switch_ || 236 op_name == switch_n_; 237 } 238 IsConv2D(TFOp op) const239bool TFGraphDialect::IsConv2D(TFOp op) const { 240 StringAttr op_name = op->getName().getIdentifier(); 241 return op_name == conv_2d_; 242 } 243 IsConv2DBackpropFilter(TFOp op) const244bool TFGraphDialect::IsConv2DBackpropFilter(TFOp op) const { 245 StringAttr op_name = op->getName().getIdentifier(); 246 return op_name == conv_2d_back_prop_filter_; 247 } 248 IsConv2DBackpropInput(TFOp op) const249bool TFGraphDialect::IsConv2DBackpropInput(TFOp op) const { 250 StringAttr op_name = op->getName().getIdentifier(); 251 return op_name == conv_2d_back_prop_input_; 252 } 253 IsConv3D(TFOp op) const254bool TFGraphDialect::IsConv3D(TFOp op) const { 255 StringAttr op_name = op->getName().getIdentifier(); 256 return op_name == conv_3d_; 257 } 258 IsConv3DBackpropFilterV2(TFOp op) const259bool TFGraphDialect::IsConv3DBackpropFilterV2(TFOp op) const { 260 StringAttr op_name = op->getName().getIdentifier(); 261 return op_name == conv_3d_back_prop_filter_v2_; 262 } 263 IsConv3DBackpropInputV2(TFOp op) const264bool TFGraphDialect::IsConv3DBackpropInputV2(TFOp op) const { 265 StringAttr op_name = op->getName().getIdentifier(); 266 return op_name == conv_3d_back_prop_input_v2_; 267 } 268 IsDepthwiseConv2dNative(TFOp op) const269bool TFGraphDialect::IsDepthwiseConv2dNative(TFOp op) const { 270 StringAttr op_name = op->getName().getIdentifier(); 271 return op_name == depth_wise_conv_2d_native_; 272 } 273 IsDepthwiseConv2dNativeBackpropFilter(TFOp op) const274bool TFGraphDialect::IsDepthwiseConv2dNativeBackpropFilter(TFOp op) const { 275 StringAttr op_name = op->getName().getIdentifier(); 276 return op_name == depth_wise_conv_2d_native_back_prop_filter_; 277 } 278 IsDepthwiseConv2dNativeBackpropInput(TFOp op) const279bool TFGraphDialect::IsDepthwiseConv2dNativeBackpropInput(TFOp op) const { 280 StringAttr op_name = op->getName().getIdentifier(); 281 return op_name == depth_wise_conv_2d_native_back_prop_input_; 282 } 283 IsDequeueOp(TFOp op) const284bool TFGraphDialect::IsDequeueOp(TFOp op) const { 285 StringAttr op_name = op->getName().getIdentifier(); 286 return op_name == queue_dequeue_ || op_name == queue_dequeue_v2_ || 287 op_name == queue_dequeue_many_ || op_name == queue_dequeue_many_v2_ || 288 op_name == queue_dequeue_upto_ || op_name == queue_dequeue_upto_v2_; 289 } 290 IsDiv(TFOp op) const291bool TFGraphDialect::IsDiv(TFOp op) const { 292 StringAttr op_name = op->getName().getIdentifier(); 293 return op_name == div_; 294 } 295 IsDivNoNan(TFOp op) const296bool TFGraphDialect::IsDivNoNan(TFOp op) const { 297 StringAttr op_name = op->getName().getIdentifier(); 298 return op_name == div_no_nan_; 299 } 300 IsElu(TFOp op) const301bool TFGraphDialect::IsElu(TFOp op) const { 302 StringAttr op_name = op->getName().getIdentifier(); 303 return op_name == elu_; 304 } 305 IsEluGrad(TFOp op) const306bool TFGraphDialect::IsEluGrad(TFOp op) const { 307 StringAttr op_name = op->getName().getIdentifier(); 308 return op_name == elu_grad_; 309 } 310 IsQuantizationEmulation(TFOp op) const311bool TFGraphDialect::IsQuantizationEmulation(TFOp op) const { 312 StringAttr op_name = op->getName().getIdentifier(); 313 return op_name == quantize_and_dequantize_ || 314 op_name == quantize_and_dequantize_v2_ || 315 op_name == quantize_and_dequantize_v3_ || 316 op_name == quantize_and_dequantize_v4_ || 317 op_name == quantize_and_dequantize_v4_grad_ || 318 op_name == fake_quant_with_min_max_args_ || 319 op_name == fake_quant_with_min_max_args_gradient_ || 320 op_name == fake_quant_with_min_max_vars_ || 321 op_name == fake_quant_with_min_max_vars_gradient_ || 322 op_name == fake_quant_with_min_max_vars_per_channel_ || 323 op_name == fake_quant_with_min_max_vars_per_channel_gradient_; 324 } 325 IsEnter(TFOp op) const326bool TFGraphDialect::IsEnter(TFOp op) const { 327 StringAttr op_name = op->getName().getIdentifier(); 328 return op_name == enter_ || op_name == ref_enter_; 329 } 330 IsEqual(TFOp op) const331bool TFGraphDialect::IsEqual(TFOp op) const { 332 StringAttr op_name = op->getName().getIdentifier(); 333 return op_name == equal_; 334 } 335 IsExit(TFOp op) const336bool TFGraphDialect::IsExit(TFOp op) const { 337 StringAttr op_name = op->getName().getIdentifier(); 338 return op_name == exit_ || op_name == ref_exit_; 339 } 340 IsExp(TFOp op) const341bool TFGraphDialect::IsExp(TFOp op) const { 342 StringAttr op_name = op->getName().getIdentifier(); 343 return op_name == exp_; 344 } 345 IsFakeParam(TFOp op) const346bool TFGraphDialect::IsFakeParam(TFOp op) const { 347 StringAttr op_name = op->getName().getIdentifier(); 348 return op_name == fake_param_; 349 } 350 IsFill(TFOp op) const351bool TFGraphDialect::IsFill(TFOp op) const { 352 StringAttr op_name = op->getName().getIdentifier(); 353 return op_name == fill_; 354 } 355 IsFloorDiv(TFOp op) const356bool TFGraphDialect::IsFloorDiv(TFOp op) const { 357 StringAttr op_name = op->getName().getIdentifier(); 358 return op_name == floor_div_; 359 } 360 IsFloorMod(TFOp op) const361bool TFGraphDialect::IsFloorMod(TFOp op) const { 362 StringAttr op_name = op->getName().getIdentifier(); 363 return op_name == floor_mod_; 364 } 365 IsFusedBatchNorm(TFOp op) const366bool TFGraphDialect::IsFusedBatchNorm(TFOp op) const { 367 StringAttr op_name = op->getName().getIdentifier(); 368 return op_name == fused_batch_norm_ || op_name == fused_batch_norm_v2_ || 369 op_name == fused_batch_norm_v3_; 370 } 371 IsFusedBatchNormEx(TFOp op) const372bool TFGraphDialect::IsFusedBatchNormEx(TFOp op) const { 373 StringAttr op_name = op->getName().getIdentifier(); 374 return op_name == fused_batch_norm_ex_; 375 } 376 IsFusedBatchNormGrad(TFOp op) const377bool TFGraphDialect::IsFusedBatchNormGrad(TFOp op) const { 378 StringAttr op_name = op->getName().getIdentifier(); 379 return op_name == fused_batch_norm_grad_ || 380 op_name == fused_batch_norm_grad_v2_ || 381 op_name == fused_batch_norm_grad_v3_; 382 } 383 IsGather(TFOp op) const384bool TFGraphDialect::IsGather(TFOp op) const { 385 StringAttr op_name = op->getName().getIdentifier(); 386 return op_name == gather_ || op_name == gather_v2_ || 387 op_name == resource_gather_; 388 } 389 IsGreater(TFOp op) const390bool TFGraphDialect::IsGreater(TFOp op) const { 391 StringAttr op_name = op->getName().getIdentifier(); 392 return op_name == greater_; 393 } 394 IsGreaterEqual(TFOp op) const395bool TFGraphDialect::IsGreaterEqual(TFOp op) const { 396 StringAttr op_name = op->getName().getIdentifier(); 397 return op_name == greater_equal_; 398 } 399 IsHostConstant(TFOp op) const400bool TFGraphDialect::IsHostConstant(TFOp op) const { 401 StringAttr op_name = op->getName().getIdentifier(); 402 return op_name == host_const_; 403 } 404 IsHistogramSummary(TFOp op) const405bool TFGraphDialect::IsHistogramSummary(TFOp op) const { 406 StringAttr op_name = op->getName().getIdentifier(); 407 return op_name == histogram_summary_; 408 } 409 IsIdentity(TFOp op) const410bool TFGraphDialect::IsIdentity(TFOp op) const { 411 StringAttr op_name = op->getName().getIdentifier(); 412 return op_name == identity_ || op_name == ref_identity_; 413 } 414 IsIdentityN(TFOp op) const415bool TFGraphDialect::IsIdentityN(TFOp op) const { 416 StringAttr op_name = op->getName().getIdentifier(); 417 return op_name == identity_n_; 418 } 419 IsIdentityNSingleInput(TFOp op) const420bool TFGraphDialect::IsIdentityNSingleInput(TFOp op) const { 421 if (!IsIdentityN(op)) return false; 422 auto array_attr = op->getAttrOfType<ArrayAttr>("T"); 423 if (!array_attr) return false; 424 // TODO(chiahungduan): Do we need to check the content of array_attr? 425 return array_attr.size() == 1; 426 } 427 IsIf(TFOp op) const428bool TFGraphDialect::IsIf(TFOp op) const { 429 StringAttr op_name = op->getName().getIdentifier(); 430 return op_name == if_ || op_name == stateless_if_; 431 } 432 IsIgamma(TFOp op) const433bool TFGraphDialect::IsIgamma(TFOp op) const { 434 StringAttr op_name = op->getName().getIdentifier(); 435 return op_name == igamma_; 436 } 437 IsIgammac(TFOp op) const438bool TFGraphDialect::IsIgammac(TFOp op) const { 439 StringAttr op_name = op->getName().getIdentifier(); 440 return op_name == igammac_; 441 } 442 IsImag(TFOp op) const443bool TFGraphDialect::IsImag(TFOp op) const { 444 StringAttr op_name = op->getName().getIdentifier(); 445 return op_name == imag_; 446 } 447 IsImmutableConst(TFOp op) const448bool TFGraphDialect::IsImmutableConst(TFOp op) const { 449 StringAttr op_name = op->getName().getIdentifier(); 450 return op_name == immutable_const_; 451 } 452 IsInvGrad(TFOp op) const453bool TFGraphDialect::IsInvGrad(TFOp op) const { 454 StringAttr op_name = op->getName().getIdentifier(); 455 return op_name == inv_grad_; 456 } 457 IsLeakyRelu(TFOp op) const458bool TFGraphDialect::IsLeakyRelu(TFOp op) const { 459 StringAttr op_name = op->getName().getIdentifier(); 460 return op_name == leaky_relu_; 461 } 462 IsLeakyReluGrad(TFOp op) const463bool TFGraphDialect::IsLeakyReluGrad(TFOp op) const { 464 StringAttr op_name = op->getName().getIdentifier(); 465 return op_name == leaky_relu_grad_; 466 } 467 IsLess(TFOp op) const468bool TFGraphDialect::IsLess(TFOp op) const { 469 StringAttr op_name = op->getName().getIdentifier(); 470 return op_name == less_; 471 } 472 IsLessEqual(TFOp op) const473bool TFGraphDialect::IsLessEqual(TFOp op) const { 474 StringAttr op_name = op->getName().getIdentifier(); 475 return op_name == less_equal_; 476 } 477 IsLog(TFOp op) const478bool TFGraphDialect::IsLog(TFOp op) const { 479 StringAttr op_name = op->getName().getIdentifier(); 480 return op_name == log_; 481 } 482 IsLogicalAnd(TFOp op) const483bool TFGraphDialect::IsLogicalAnd(TFOp op) const { 484 StringAttr op_name = op->getName().getIdentifier(); 485 return op_name == logical_and_; 486 } 487 IsLogicalNot(TFOp op) const488bool TFGraphDialect::IsLogicalNot(TFOp op) const { 489 StringAttr op_name = op->getName().getIdentifier(); 490 return op_name == logical_not_; 491 } 492 IsLogicalOr(TFOp op) const493bool TFGraphDialect::IsLogicalOr(TFOp op) const { 494 StringAttr op_name = op->getName().getIdentifier(); 495 return op_name == logical_or_; 496 } 497 IsLoopCond(TFOp op) const498bool TFGraphDialect::IsLoopCond(TFOp op) const { 499 StringAttr op_name = op->getName().getIdentifier(); 500 return op_name == loop_cond_; 501 } 502 IsMatMul(TFOp op) const503bool TFGraphDialect::IsMatMul(TFOp op) const { 504 StringAttr op_name = op->getName().getIdentifier(); 505 return op_name == matmul_; 506 } 507 IsMax(TFOp op) const508bool TFGraphDialect::IsMax(TFOp op) const { 509 StringAttr op_name = op->getName().getIdentifier(); 510 return op_name == max_; 511 } 512 IsMaximum(TFOp op) const513bool TFGraphDialect::IsMaximum(TFOp op) const { 514 StringAttr op_name = op->getName().getIdentifier(); 515 return op_name == maximum_; 516 } 517 IsMaxPoolGrad(TFOp op) const518bool TFGraphDialect::IsMaxPoolGrad(TFOp op) const { 519 StringAttr op_name = op->getName().getIdentifier(); 520 return op_name == max_pool_grad_; 521 } 522 IsMean(TFOp op) const523bool TFGraphDialect::IsMean(TFOp op) const { 524 StringAttr op_name = op->getName().getIdentifier(); 525 return op_name == mean_; 526 } 527 IsMerge(TFOp op) const528bool TFGraphDialect::IsMerge(TFOp op) const { 529 StringAttr op_name = op->getName().getIdentifier(); 530 return op_name == merge_ || op_name == ref_merge_ || op_name == xla_merge_; 531 } 532 IsMin(TFOp op) const533bool TFGraphDialect::IsMin(TFOp op) const { 534 StringAttr op_name = op->getName().getIdentifier(); 535 return op_name == min_; 536 } 537 IsMinimum(TFOp op) const538bool TFGraphDialect::IsMinimum(TFOp op) const { 539 StringAttr op_name = op->getName().getIdentifier(); 540 return op_name == minimum_; 541 } 542 IsMirrorPad(TFOp op) const543bool TFGraphDialect::IsMirrorPad(TFOp op) const { 544 StringAttr op_name = op->getName().getIdentifier(); 545 return op_name == mirror_pad_; 546 } 547 IsMirrorPadGrad(TFOp op) const548bool TFGraphDialect::IsMirrorPadGrad(TFOp op) const { 549 StringAttr op_name = op->getName().getIdentifier(); 550 return op_name == mirror_pad_grad_; 551 } 552 IsMod(TFOp op) const553bool TFGraphDialect::IsMod(TFOp op) const { 554 StringAttr op_name = op->getName().getIdentifier(); 555 return op_name == mod_; 556 } 557 IsMul(TFOp op) const558bool TFGraphDialect::IsMul(TFOp op) const { 559 StringAttr op_name = op->getName().getIdentifier(); 560 return op_name == mul_; 561 } IsMulNoNan(TFOp op) const562bool TFGraphDialect::IsMulNoNan(TFOp op) const { 563 StringAttr op_name = op->getName().getIdentifier(); 564 return op_name == mul_no_nan_; 565 } IsAnyMul(TFOp op) const566bool TFGraphDialect::IsAnyMul(TFOp op) const { 567 return IsMul(op) || IsMulNoNan(op); 568 } 569 IsNeg(TFOp op) const570bool TFGraphDialect::IsNeg(TFOp op) const { 571 StringAttr op_name = op->getName().getIdentifier(); 572 return op_name == neg_; 573 } 574 IsNoOp(TFOp op) const575bool TFGraphDialect::IsNoOp(TFOp op) const { 576 StringAttr op_name = op->getName().getIdentifier(); 577 return op_name == no_op_; 578 } 579 IsNotEqual(TFOp op) const580bool TFGraphDialect::IsNotEqual(TFOp op) const { 581 StringAttr op_name = op->getName().getIdentifier(); 582 return op_name == not_equal_; 583 } 584 IsNextIteration(TFOp op) const585bool TFGraphDialect::IsNextIteration(TFOp op) const { 586 StringAttr op_name = op->getName().getIdentifier(); 587 return op_name == next_iteration_ || op_name == ref_next_iteration_; 588 } 589 IsOnesLike(TFOp op) const590bool TFGraphDialect::IsOnesLike(TFOp op) const { 591 StringAttr op_name = op->getName().getIdentifier(); 592 return op_name == ones_like_; 593 } 594 IsPack(TFOp op) const595bool TFGraphDialect::IsPack(TFOp op) const { 596 StringAttr op_name = op->getName().getIdentifier(); 597 return op_name == pack_; 598 } 599 IsPad(TFOp op) const600bool TFGraphDialect::IsPad(TFOp op) const { 601 StringAttr op_name = op->getName().getIdentifier(); 602 return op_name == pad_ || op_name == pad_v2_; 603 } 604 IsPartitionedCall(TFOp op) const605bool TFGraphDialect::IsPartitionedCall(TFOp op) const { 606 StringAttr op_name = op->getName().getIdentifier(); 607 return op_name == partitioned_call_; 608 } 609 IsPlaceholder(TFOp op) const610bool TFGraphDialect::IsPlaceholder(TFOp op) const { 611 StringAttr op_name = op->getName().getIdentifier(); 612 return op_name == placeholder_ || op_name == placeholder_v2_ || 613 op_name == placeholder_with_default_; 614 } 615 IsPolygamma(TFOp op) const616bool TFGraphDialect::IsPolygamma(TFOp op) const { 617 StringAttr op_name = op->getName().getIdentifier(); 618 return op_name == poly_gamma_; 619 } 620 IsPow(TFOp op) const621bool TFGraphDialect::IsPow(TFOp op) const { 622 StringAttr op_name = op->getName().getIdentifier(); 623 return op_name == pow_; 624 } 625 IsPrint(TFOp op) const626bool TFGraphDialect::IsPrint(TFOp op) const { 627 StringAttr op_name = op->getName().getIdentifier(); 628 return op_name == print_ || op_name == print_v2_; 629 } 630 IsProd(TFOp op) const631bool TFGraphDialect::IsProd(TFOp op) const { 632 StringAttr op_name = op->getName().getIdentifier(); 633 return op_name == prod_; 634 } 635 IsQuantizedMatMul(TFOp op) const636bool TFGraphDialect::IsQuantizedMatMul(TFOp op) const { 637 StringAttr op_name = op->getName().getIdentifier(); 638 return op_name == quantized_matmul_ || op_name == quantized_matmul_v2_; 639 } 640 IsQueue(TFOp op) const641bool TFGraphDialect::IsQueue(TFOp op) const { 642 StringAttr op_name = op->getName().getIdentifier(); 643 return op_name == random_shuffle_queue_v2_ || op_name == fifo_queue_v2_ || 644 op_name == padding_fifo_queue_v2_ || op_name == priority_queue_v2_; 645 } 646 IsRandomShuffle(TFOp op) const647bool TFGraphDialect::IsRandomShuffle(TFOp op) const { 648 StringAttr op_name = op->getName().getIdentifier(); 649 return op_name == random_shuffle_; 650 } 651 IsRank(TFOp op) const652bool TFGraphDialect::IsRank(TFOp op) const { 653 StringAttr op_name = op->getName().getIdentifier(); 654 return op_name == rank_; 655 } 656 IsReadVariableOp(TFOp op) const657bool TFGraphDialect::IsReadVariableOp(TFOp op) const { 658 StringAttr op_name = op->getName().getIdentifier(); 659 return op_name == read_variable_op_; 660 } 661 IsReadVariablesOp(TFOp op) const662bool TFGraphDialect::IsReadVariablesOp(TFOp op) const { 663 StringAttr op_name = op->getName().getIdentifier(); 664 return op_name == read_variables_op_; 665 } 666 IsReal(TFOp op) const667bool TFGraphDialect::IsReal(TFOp op) const { 668 StringAttr op_name = op->getName().getIdentifier(); 669 return op_name == real_; 670 } 671 IsRealDiv(TFOp op) const672bool TFGraphDialect::IsRealDiv(TFOp op) const { 673 StringAttr op_name = op->getName().getIdentifier(); 674 return op_name == real_div_; 675 } 676 IsReciprocalGrad(TFOp op) const677bool TFGraphDialect::IsReciprocalGrad(TFOp op) const { 678 StringAttr op_name = op->getName().getIdentifier(); 679 return op_name == reciprocal_grad_; 680 } 681 IsRecv(TFOp op) const682bool TFGraphDialect::IsRecv(TFOp op) const { 683 StringAttr op_name = op->getName().getIdentifier(); 684 return op_name == recv_ || op_name == host_recv_; 685 } 686 IsReduction(TFOp op) const687bool TFGraphDialect::IsReduction(TFOp op) const { 688 return IsSum(op) || IsProd(op) || IsMin(op) || IsMax(op) || IsMean(op) || 689 IsAny(op) || IsAll(op); 690 } 691 IsRelu(TFOp op) const692bool TFGraphDialect::IsRelu(TFOp op) const { 693 StringAttr op_name = op->getName().getIdentifier(); 694 return op_name == relu_; 695 } 696 IsRelu6(TFOp op) const697bool TFGraphDialect::IsRelu6(TFOp op) const { 698 StringAttr op_name = op->getName().getIdentifier(); 699 return op_name == relu6_; 700 } 701 IsReluGrad(TFOp op) const702bool TFGraphDialect::IsReluGrad(TFOp op) const { 703 StringAttr op_name = op->getName().getIdentifier(); 704 return op_name == relu_grad_; 705 } 706 IsRelu6Grad(TFOp op) const707bool TFGraphDialect::IsRelu6Grad(TFOp op) const { 708 StringAttr op_name = op->getName().getIdentifier(); 709 return op_name == relu6_grad_; 710 } 711 IsReshape(TFOp op) const712bool TFGraphDialect::IsReshape(TFOp op) const { 713 StringAttr op_name = op->getName().getIdentifier(); 714 return op_name == reshape_; 715 } 716 IsRestore(TFOp op) const717bool TFGraphDialect::IsRestore(TFOp op) const { 718 StringAttr op_name = op->getName().getIdentifier(); 719 return op_name == restore_ || op_name == restore_v2_ || 720 op_name == restore_slice_; 721 } 722 IsReturn(TFOp op) const723bool TFGraphDialect::IsReturn(TFOp op) const { 724 StringAttr op_name = op->getName().getIdentifier(); 725 return op_name == return_; 726 } 727 IsRetval(TFOp op) const728bool TFGraphDialect::IsRetval(TFOp op) const { 729 StringAttr op_name = op->getName().getIdentifier(); 730 return op_name == retval_ || op_name == device_retval_; 731 } 732 IsReverse(TFOp op) const733bool TFGraphDialect::IsReverse(TFOp op) const { 734 StringAttr op_name = op->getName().getIdentifier(); 735 return op_name == reverse_ || IsReverseV2(op); 736 } 737 IsReverseV2(TFOp op) const738bool TFGraphDialect::IsReverseV2(TFOp op) const { 739 StringAttr op_name = op->getName().getIdentifier(); 740 return op_name == reverse_v2_; 741 } 742 IsRsqrt(TFOp op) const743bool TFGraphDialect::IsRsqrt(TFOp op) const { 744 StringAttr op_name = op->getName().getIdentifier(); 745 return op_name == rsqrt_; 746 } 747 IsRsqrtGrad(TFOp op) const748bool TFGraphDialect::IsRsqrtGrad(TFOp op) const { 749 StringAttr op_name = op->getName().getIdentifier(); 750 return op_name == rsqrt_grad_; 751 } 752 IsSelect(TFOp op) const753bool TFGraphDialect::IsSelect(TFOp op) const { 754 StringAttr op_name = op->getName().getIdentifier(); 755 return op_name == select_ || op_name == select_v2_; 756 } 757 IsSeluGrad(TFOp op) const758bool TFGraphDialect::IsSeluGrad(TFOp op) const { 759 StringAttr op_name = op->getName().getIdentifier(); 760 return op_name == selu_grad_; 761 } 762 IsSend(TFOp op) const763bool TFGraphDialect::IsSend(TFOp op) const { 764 StringAttr op_name = op->getName().getIdentifier(); 765 return op_name == send_ || op_name == host_send_; 766 } 767 IsShape(TFOp op) const768bool TFGraphDialect::IsShape(TFOp op) const { 769 StringAttr op_name = op->getName().getIdentifier(); 770 return op_name == shape_; 771 } 772 IsShapeN(TFOp op) const773bool TFGraphDialect::IsShapeN(TFOp op) const { 774 StringAttr op_name = op->getName().getIdentifier(); 775 return op_name == shape_n_; 776 } 777 IsShuffle(TFOp op) const778bool TFGraphDialect::IsShuffle(TFOp op) const { 779 StringAttr op_name = op->getName().getIdentifier(); 780 return op_name == shuffle_; 781 } 782 IsSigmoid(TFOp op) const783bool TFGraphDialect::IsSigmoid(TFOp op) const { 784 StringAttr op_name = op->getName().getIdentifier(); 785 return op_name == sigmoid_; 786 } 787 IsSigmoidGrad(TFOp op) const788bool TFGraphDialect::IsSigmoidGrad(TFOp op) const { 789 StringAttr op_name = op->getName().getIdentifier(); 790 return op_name == sigmoid_grad_; 791 } 792 IsSize(TFOp op) const793bool TFGraphDialect::IsSize(TFOp op) const { 794 StringAttr op_name = op->getName().getIdentifier(); 795 return op_name == size_; 796 } 797 IsSlice(TFOp op) const798bool TFGraphDialect::IsSlice(TFOp op) const { 799 StringAttr op_name = op->getName().getIdentifier(); 800 return op_name == slice_; 801 } 802 IsSnapshot(TFOp op) const803bool TFGraphDialect::IsSnapshot(TFOp op) const { 804 StringAttr op_name = op->getName().getIdentifier(); 805 return op_name == snapshot_; 806 } 807 IsSoftmax(TFOp op) const808bool TFGraphDialect::IsSoftmax(TFOp op) const { 809 StringAttr op_name = op->getName().getIdentifier(); 810 return op_name == softmax_; 811 } 812 IsSoftplusGrad(TFOp op) const813bool TFGraphDialect::IsSoftplusGrad(TFOp op) const { 814 StringAttr op_name = op->getName().getIdentifier(); 815 return op_name == softplus_grad_; 816 } 817 IsSoftsignGrad(TFOp op) const818bool TFGraphDialect::IsSoftsignGrad(TFOp op) const { 819 StringAttr op_name = op->getName().getIdentifier(); 820 return op_name == softsign_grad_; 821 } 822 IsSplit(TFOp op) const823bool TFGraphDialect::IsSplit(TFOp op) const { 824 StringAttr op_name = op->getName().getIdentifier(); 825 return op_name == split_; 826 } 827 IsSplitV(TFOp op) const828bool TFGraphDialect::IsSplitV(TFOp op) const { 829 StringAttr op_name = op->getName().getIdentifier(); 830 return op_name == split_v_; 831 } 832 IsSqrt(TFOp op) const833bool TFGraphDialect::IsSqrt(TFOp op) const { 834 StringAttr op_name = op->getName().getIdentifier(); 835 return op_name == sqrt_; 836 } 837 IsSqrtGrad(TFOp op) const838bool TFGraphDialect::IsSqrtGrad(TFOp op) const { 839 StringAttr op_name = op->getName().getIdentifier(); 840 return op_name == sqrt_grad_; 841 } 842 IsSquare(TFOp op) const843bool TFGraphDialect::IsSquare(TFOp op) const { 844 StringAttr op_name = op->getName().getIdentifier(); 845 return op_name == square_; 846 } 847 IsSquaredDifference(TFOp op) const848bool TFGraphDialect::IsSquaredDifference(TFOp op) const { 849 StringAttr op_name = op->getName().getIdentifier(); 850 return op_name == squared_difference_; 851 } 852 IsSqueeze(TFOp op) const853bool TFGraphDialect::IsSqueeze(TFOp op) const { 854 StringAttr op_name = op->getName().getIdentifier(); 855 return op_name == squeeze_; 856 } 857 IsStackOp(TFOp op) const858bool TFGraphDialect::IsStackOp(TFOp op) const { 859 StringAttr op_name = op->getName().getIdentifier(); 860 return op_name == stack_ || op_name == stack_v2_; 861 } 862 IsStackCloseOp(TFOp op) const863bool TFGraphDialect::IsStackCloseOp(TFOp op) const { 864 StringAttr op_name = op->getName().getIdentifier(); 865 return op_name == stack_close_ || op_name == stack_close_v2_; 866 } 867 IsStackPushOp(TFOp op) const868bool TFGraphDialect::IsStackPushOp(TFOp op) const { 869 StringAttr op_name = op->getName().getIdentifier(); 870 return op_name == stack_push_ || op_name == stack_push_v2_; 871 } 872 IsStackPopOp(TFOp op) const873bool TFGraphDialect::IsStackPopOp(TFOp op) const { 874 StringAttr op_name = op->getName().getIdentifier(); 875 return op_name == stack_pop_ || op_name == stack_pop_v2_; 876 } 877 IsStatefulPartitionedCall(TFOp op) const878bool TFGraphDialect::IsStatefulPartitionedCall(TFOp op) const { 879 StringAttr op_name = op->getName().getIdentifier(); 880 return op_name == stateful_partitioned_call_; 881 } 882 IsStopGradient(TFOp op) const883bool TFGraphDialect::IsStopGradient(TFOp op) const { 884 StringAttr op_name = op->getName().getIdentifier(); 885 return op_name == stop_gradient_ || op_name == prevent_gradient_; 886 } 887 IsStridedSlice(TFOp op) const888bool TFGraphDialect::IsStridedSlice(TFOp op) const { 889 StringAttr op_name = op->getName().getIdentifier(); 890 return op_name == strided_slice_; 891 } 892 IsStridedSliceGrad(TFOp op) const893bool TFGraphDialect::IsStridedSliceGrad(TFOp op) const { 894 StringAttr op_name = op->getName().getIdentifier(); 895 return op_name == strided_slice_grad_; 896 } 897 IsStringToHashBucketFast(TFOp op) const898bool TFGraphDialect::IsStringToHashBucketFast(TFOp op) const { 899 StringAttr op_name = op->getName().getIdentifier(); 900 return op_name == string_to_hashbucket_fast_; 901 } 902 IsSub(TFOp op) const903bool TFGraphDialect::IsSub(TFOp op) const { 904 StringAttr op_name = op->getName().getIdentifier(); 905 return op_name == sub_; 906 } 907 IsSum(TFOp op) const908bool TFGraphDialect::IsSum(TFOp op) const { 909 StringAttr op_name = op->getName().getIdentifier(); 910 return op_name == sum_; 911 } 912 IsSwitch(TFOp op) const913bool TFGraphDialect::IsSwitch(TFOp op) const { 914 StringAttr op_name = op->getName().getIdentifier(); 915 return op_name == switch_ || op_name == switch_n_ || op_name == ref_switch_; 916 } 917 IsSymbolicGradient(TFOp op) const918bool TFGraphDialect::IsSymbolicGradient(TFOp op) const { 919 StringAttr op_name = op->getName().getIdentifier(); 920 return op_name == symbolic_gradient_; 921 } 922 IsTanh(TFOp op) const923bool TFGraphDialect::IsTanh(TFOp op) const { 924 StringAttr op_name = op->getName().getIdentifier(); 925 return op_name == tanh_; 926 } 927 IsTanhGrad(TFOp op) const928bool TFGraphDialect::IsTanhGrad(TFOp op) const { 929 StringAttr op_name = op->getName().getIdentifier(); 930 return op_name == tanh_grad_; 931 } 932 IsTile(TFOp op) const933bool TFGraphDialect::IsTile(TFOp op) const { 934 StringAttr op_name = op->getName().getIdentifier(); 935 return op_name == tile_; 936 } 937 IsTranspose(TFOp op) const938bool TFGraphDialect::IsTranspose(TFOp op) const { 939 StringAttr op_name = op->getName().getIdentifier(); 940 return op_name == transpose_; 941 } 942 IsTruncateDiv(TFOp op) const943bool TFGraphDialect::IsTruncateDiv(TFOp op) const { 944 StringAttr op_name = op->getName().getIdentifier(); 945 return op_name == truncate_div_; 946 } 947 IsTruncateMod(TFOp op) const948bool TFGraphDialect::IsTruncateMod(TFOp op) const { 949 StringAttr op_name = op->getName().getIdentifier(); 950 return op_name == truncate_mod_; 951 } 952 IsUnique(TFOp op) const953bool TFGraphDialect::IsUnique(TFOp op) const { 954 StringAttr op_name = op->getName().getIdentifier(); 955 return op_name == unique_ || op_name == unique_v2_; 956 } 957 IsUnpack(TFOp op) const958bool TFGraphDialect::IsUnpack(TFOp op) const { 959 StringAttr op_name = op->getName().getIdentifier(); 960 return op_name == unpack_; 961 } 962 IsVariable(TFOp op) const963bool TFGraphDialect::IsVariable(TFOp op) const { 964 StringAttr op_name = op->getName().getIdentifier(); 965 return op_name == variable_ || op_name == variable_v2_ || 966 op_name == auto_reload_variable_ || op_name == var_handle_op_ || 967 op_name == var_handles_op_ || IsReadVariableOp(op) || 968 IsReadVariablesOp(op); 969 } 970 IsWhile(TFOp op) const971bool TFGraphDialect::IsWhile(TFOp op) const { 972 StringAttr op_name = op->getName().getIdentifier(); 973 return op_name == while_ || op_name == stateless_while_; 974 } 975 IsXdivy(TFOp op) const976bool TFGraphDialect::IsXdivy(TFOp op) const { 977 StringAttr op_name = op->getName().getIdentifier(); 978 return op_name == xdivy_; 979 } 980 IsZerosLike(TFOp op) const981bool TFGraphDialect::IsZerosLike(TFOp op) const { 982 StringAttr op_name = op->getName().getIdentifier(); 983 return op_name == zeros_like_; 984 } 985 IsZeta(TFOp op) const986bool TFGraphDialect::IsZeta(TFOp op) const { 987 StringAttr op_name = op->getName().getIdentifier(); 988 return op_name == zeta_; 989 } 990 991 } // namespace tfg 992 } // namespace mlir 993