xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/op_types.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/op_types.h"
17 
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/utils.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/gtl/flatset.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
IsAdd(const NodeDef & node)30 bool IsAdd(const NodeDef& node) {
31   if (node.op() == "AddV2") {
32     return true;
33   }
34   if (node.op() == "Add") {
35     DataType type = node.attr().at("T").type();
36     return type != DT_STRING;
37   }
38   return false;
39 }
40 
IsAddN(const NodeDef & node)41 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
42 
IsAll(const NodeDef & node)43 bool IsAll(const NodeDef& node) { return node.op() == "All"; }
44 
IsAngle(const NodeDef & node)45 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
46 
IsAny(const NodeDef & node)47 bool IsAny(const NodeDef& node) { return node.op() == "Any"; }
48 
IsAnyDiv(const NodeDef & node)49 bool IsAnyDiv(const NodeDef& node) {
50   return node.op() == "RealDiv" || node.op() == "Div" || node.op() == "Xdivy" ||
51          node.op() == "FloorDiv" || node.op() == "TruncateDiv";
52 }
53 
IsAnyBatchMatMul(const NodeDef & node)54 bool IsAnyBatchMatMul(const NodeDef& node) {
55   return node.op() == "BatchMatMul" || node.op() == "BatchMatMulV2";
56 }
57 
IsAnyMatMul(const NodeDef & node)58 bool IsAnyMatMul(const NodeDef& node) {
59   return node.op() == "MatMul" || node.op() == "SparseMatMul" ||
60          IsAnyBatchMatMul(node) || IsQuantizedMatMul(node);
61 }
62 
IsAnyMax(const NodeDef & node)63 bool IsAnyMax(const NodeDef& node) {
64   const auto& op = node.op();
65   return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
66 }
67 
IsAnyMaxPool(const NodeDef & node)68 bool IsAnyMaxPool(const NodeDef& node) {
69   const auto& op = node.op();
70   return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
71          op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
72 }
73 
IsAnyMin(const NodeDef & node)74 bool IsAnyMin(const NodeDef& node) {
75   const auto& op = node.op();
76   return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
77 }
78 
IsAnySparseSegmentReduction(const NodeDef & node)79 bool IsAnySparseSegmentReduction(const NodeDef& node) {
80   const auto& op = node.op();
81   return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" ||
82          op == "SparseSegmentMean" ||
83          op == "SparseSegmentMeanWithNumSegments" ||
84          op == "SparseSegmentSqrtN" ||
85          op == "SparseSegmentSqrtNWithNumSegments";
86 }
87 
IsApproximateEqual(const NodeDef & node)88 bool IsApproximateEqual(const NodeDef& node) {
89   return node.op() == "ApproximateEqual";
90 }
91 
IsArg(const NodeDef & node)92 bool IsArg(const NodeDef& node) {
93   return node.op() == "_Arg" || node.op() == "_DeviceArg";
94 }
95 
IsArgMax(const NodeDef & node)96 bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }
97 
IsArgMin(const NodeDef & node)98 bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
99 
IsAvgPoolGrad(const NodeDef & node)100 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
101 
IsAssign(const NodeDef & node)102 bool IsAssign(const NodeDef& node) {
103   return node.op() == "Assign" || node.op() == "AssignVariableOp";
104 }
105 
IsAssert(const NodeDef & node)106 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
107 
IsAsString(const NodeDef & node)108 bool IsAsString(const NodeDef& node) { return node.op() == "AsString"; }
109 
IsAtan2(const NodeDef & node)110 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
111 
IsBetainc(const NodeDef & node)112 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }
113 
IsBiasAdd(const NodeDef & node)114 bool IsBiasAdd(const NodeDef& node) {
115   return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
116 }
117 
IsBiasAddV2(const NodeDef & node)118 bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }
119 
IsBiasAddGrad(const NodeDef & node)120 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
121 
IsBitcast(const NodeDef & node)122 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
123 
IsBroadcastTo(const NodeDef & node)124 bool IsBroadcastTo(const NodeDef& node) { return node.op() == "BroadcastTo"; }
125 
IsCast(const NodeDef & node)126 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
127 
IsCastLike(const NodeDef & node)128 bool IsCastLike(const NodeDef& node) {
129   static const gtl::FlatSet<string>* const kCastLikeOps =
130       CHECK_NOTNULL((new gtl::FlatSet<string>{
131           "Angle", "Bucketize", "Cast", "Dequantize", "HistogramFixedWidth",
132           "Imag", "IsFinite", "IsInf", "IsNan", "Quantize",
133           "QuantizeDownAndShrinkRange", "QuantizeV2", "QuantizedInstanceNorm",
134           "QuantizedRelu", "QuantizedRelu6", "QuantizedReluX", "Real",
135           "Requantize"}));
136   return kCastLikeOps->count(node.op()) > 0;
137 }
138 
IsCheckNumerics(const NodeDef & node)139 bool IsCheckNumerics(const NodeDef& node) {
140   return node.op() == "CheckNumerics";
141 }
142 
IsCollective(const NodeDef & node)143 bool IsCollective(const NodeDef& node) {
144   return node.op() == "CollectiveReduce" ||
145          node.op() == "CollectiveBcastSend" ||
146          node.op() == "CollectiveBcastRecv";
147 }
148 
IsComplex(const NodeDef & node)149 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
150 
IsComplexAbs(const NodeDef & node)151 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
152 
IsConcat(const NodeDef & node)153 bool IsConcat(const NodeDef& node) {
154   return node.op() == "Concat" || node.op() == "ConcatV2";
155 }
156 
IsConcatOffset(const NodeDef & node)157 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
158 
IsConstant(const NodeDef & node)159 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
160 
IsConj(const NodeDef & node)161 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
162 
IsConjugateTranspose(const NodeDef & node)163 bool IsConjugateTranspose(const NodeDef& node) {
164   return node.op() == "ConjugateTranspose";
165 }
166 
IsControlFlow(const NodeDef & node)167 bool IsControlFlow(const NodeDef& node) {
168   // clang-format off
169   return node.op() == "ControlTrigger" ||
170          node.op() == "Enter" ||
171          node.op() == "Exit" ||
172          node.op() == "LoopCond" ||
173          node.op() == "Merge" ||
174          node.op() == "_XlaMerge" ||
175          node.op() == "NextIteration" ||
176          node.op() == "Switch" ||
177          node.op() == "_SwitchN";
178   // clang-format on
179 }
180 
IsConv2D(const NodeDef & node)181 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
182 
IsConv2DBackpropFilter(const NodeDef & node)183 bool IsConv2DBackpropFilter(const NodeDef& node) {
184   return node.op() == "Conv2DBackpropFilter";
185 }
186 
IsConv2DBackpropInput(const NodeDef & node)187 bool IsConv2DBackpropInput(const NodeDef& node) {
188   return node.op() == "Conv2DBackpropInput";
189 }
190 
IsConv3D(const NodeDef & node)191 bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }
192 
IsConv3DBackpropFilterV2(const NodeDef & node)193 bool IsConv3DBackpropFilterV2(const NodeDef& node) {
194   return node.op() == "Conv3DBackpropFilterV2";
195 }
196 
IsConv3DBackpropInputV2(const NodeDef & node)197 bool IsConv3DBackpropInputV2(const NodeDef& node) {
198   return node.op() == "Conv3DBackpropInputV2";
199 }
200 
IsDepthwiseConv2dNative(const NodeDef & node)201 bool IsDepthwiseConv2dNative(const NodeDef& node) {
202   return node.op() == "DepthwiseConv2dNative";
203 }
204 
IsDepthwiseConv2dNativeBackpropFilter(const NodeDef & node)205 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
206   return node.op() == "DepthwiseConv2dNativeBackpropFilter";
207 }
208 
IsDepthwiseConv2dNativeBackpropInput(const NodeDef & node)209 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
210   return node.op() == "DepthwiseConv2dNativeBackpropInput";
211 }
212 
IsDequeueOp(const NodeDef & node)213 bool IsDequeueOp(const NodeDef& node) {
214   const auto& op = node.op();
215   return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
216          op == "QueueDequeueV2" || op == "QueueDequeue" ||
217          op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
218 }
219 
IsDiv(const NodeDef & node)220 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
221 
IsDivNoNan(const NodeDef & node)222 bool IsDivNoNan(const NodeDef& node) { return node.op() == "DivNoNan"; }
223 
224 // Returns true if node represents a unary elementwise function that is
225 // monotonic. If *is_non_decreasing is true, the function is non-decreasing,
226 // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
227 // e.g. inv.
IsElementWiseMonotonic(const NodeDef & node,bool * is_non_decreasing)228 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
229   static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
230       CHECK_NOTNULL((new gtl::FlatSet<string>{
231           "Acosh", "Asin", "Asinh",    "Atan",     "Atanh", "Ceil",
232           "Elu",   "Erf",  "Exp",      "Expm1",    "Floor", "Log",
233           "Log1p", "Relu", "Relu6",    "Rint",     "Selu",  "Sigmoid",
234           "Sign",  "Sinh", "Softsign", "Softplus", "Sqrt",  "Tanh",
235       }));
236   static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
237       CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
238   if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
239     if (is_non_decreasing) {
240       *is_non_decreasing = true;
241     }
242     return true;
243   } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
244     if (is_non_decreasing) {
245       *is_non_decreasing = false;
246     }
247     return true;
248   }
249   return false;
250 }
251 
IsElu(const NodeDef & node)252 bool IsElu(const NodeDef& node) { return node.op() == "Elu"; }
253 
IsEluGrad(const NodeDef & node)254 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
255 
IsQuantizationEmulation(const NodeDef & node)256 bool IsQuantizationEmulation(const NodeDef& node) {
257   const auto& op = node.op();
258   return absl::StartsWith(op, "QuantizeAndDequantize") ||
259          absl::StartsWith(op, "FakeQuantWithMinMax");
260 }
261 
IsEnter(const NodeDef & node)262 bool IsEnter(const NodeDef& node) {
263   const auto& op = node.op();
264   return op == "Enter" || op == "RefEnter";
265 }
266 
IsEqual(const NodeDef & node)267 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
268 
IsExit(const NodeDef & node)269 bool IsExit(const NodeDef& node) {
270   const auto& op = node.op();
271   return op == "Exit" || op == "RefExit";
272 }
273 
IsExp(const NodeDef & node)274 bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
275 
IsFakeParam(const NodeDef & node)276 bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }
277 
IsFill(const NodeDef & node)278 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
279 
IsFloorDiv(const NodeDef & node)280 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
281 
IsFloorMod(const NodeDef & node)282 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
283 
IsFusedBatchNorm(const NodeDef & node)284 bool IsFusedBatchNorm(const NodeDef& node) {
285   const auto& op = node.op();
286   return op == "FusedBatchNorm" || op == "FusedBatchNormV2" ||
287          op == "FusedBatchNormV3";
288 }
289 
IsFusedBatchNormEx(const NodeDef & node)290 bool IsFusedBatchNormEx(const NodeDef& node) {
291   return node.op() == "_FusedBatchNormEx";
292 }
293 
IsFusedBatchNormGrad(const NodeDef & node)294 bool IsFusedBatchNormGrad(const NodeDef& node) {
295   const auto& op = node.op();
296   return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2" ||
297          op == "FusedBatchNormGradV3";
298 }
299 
IsGather(const NodeDef & node)300 bool IsGather(const NodeDef& node) {
301   const auto& op = node.op();
302   return op == "Gather" || op == "GatherV2" || op == "ResourceGather";
303 }
304 
IsGreater(const NodeDef & node)305 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
306 
IsGreaterEqual(const NodeDef & node)307 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
308 
IsHostConstant(const NodeDef & node)309 bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }
310 
IsHistogramSummary(const NodeDef & node)311 bool IsHistogramSummary(const NodeDef& node) {
312   return node.op() == "HistogramSummary";
313 }
314 
IsIdentity(const NodeDef & node)315 bool IsIdentity(const NodeDef& node) {
316   const auto& op = node.op();
317   return op == "Identity" || op == "RefIdentity";
318 }
319 
IsIdentityN(const NodeDef & node)320 bool IsIdentityN(const NodeDef& node) {
321   const auto& op = node.op();
322   return op == "IdentityN";
323 }
324 
IsIdentityNSingleInput(const NodeDef & node)325 bool IsIdentityNSingleInput(const NodeDef& node) {
326   return IsIdentityN(node) && node.attr().count("T") != 0 &&
327          node.attr().at("T").list().type_size() == 1;
328 }
329 
IsIf(const NodeDef & node)330 bool IsIf(const NodeDef& node) {
331   const auto& op = node.op();
332   return op == "If" || op == "StatelessIf";
333 }
334 
IsIgamma(const NodeDef & node)335 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
336 
IsIgammac(const NodeDef & node)337 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
338 
IsImag(const NodeDef & node)339 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
340 
IsImmutableConst(const NodeDef & node)341 bool IsImmutableConst(const NodeDef& node) {
342   return node.op() == "ImmutableConst";
343 }
344 
IsInvGrad(const NodeDef & node)345 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
346 
IsLeakyRelu(const NodeDef & node)347 bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }
348 
IsLeakyReluGrad(const NodeDef & node)349 bool IsLeakyReluGrad(const NodeDef& node) {
350   return node.op() == "LeakyReluGrad";
351 }
352 
IsLess(const NodeDef & node)353 bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
354 
IsLessEqual(const NodeDef & node)355 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
356 
IsLog(const NodeDef & node)357 bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
358 
IsLogicalAnd(const NodeDef & node)359 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
360 
IsLogicalNot(const NodeDef & node)361 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
362 
IsLogicalOr(const NodeDef & node)363 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
364 
IsLoopCond(const NodeDef & node)365 bool IsLoopCond(const NodeDef& node) { return node.op() == "LoopCond"; }
366 
IsMatMul(const NodeDef & node)367 bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
368 
IsMax(const NodeDef & node)369 bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
370 
IsMaximum(const NodeDef & node)371 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
372 
IsMaxPoolGrad(const NodeDef & node)373 bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }
374 
IsMean(const NodeDef & node)375 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
376 
IsMerge(const NodeDef & node)377 bool IsMerge(const NodeDef& node) {
378   const auto& op = node.op();
379   return op == "Merge" || op == "RefMerge" || op == "_XlaMerge";
380 }
381 
IsMin(const NodeDef & node)382 bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
383 
IsMinimum(const NodeDef & node)384 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
385 
IsMirrorPad(const NodeDef & node)386 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }
387 
IsMirrorPadGrad(const NodeDef & node)388 bool IsMirrorPadGrad(const NodeDef& node) {
389   return node.op() == "MirrorPadGrad";
390 }
391 
IsMod(const NodeDef & node)392 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
393 
IsMul(const NodeDef & node)394 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
IsMulNoNan(const NodeDef & node)395 bool IsMulNoNan(const NodeDef& node) { return node.op() == "MulNoNan"; }
IsAnyMul(const NodeDef & node)396 bool IsAnyMul(const NodeDef& node) { return IsMul(node) || IsMulNoNan(node); }
397 
IsNeg(const NodeDef & node)398 bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }
399 
IsNoOp(const NodeDef & node)400 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
401 
IsNotEqual(const NodeDef & node)402 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
403 
IsNextIteration(const NodeDef & node)404 bool IsNextIteration(const NodeDef& node) {
405   const auto& op = node.op();
406   return op == "NextIteration" || op == "RefNextIteration";
407 }
408 
IsOnesLike(const NodeDef & node)409 bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }
410 
IsPack(const NodeDef & node)411 bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
412 
IsPad(const NodeDef & node)413 bool IsPad(const NodeDef& node) {
414   const auto& op = node.op();
415   return op == "Pad" || op == "PadV2";
416 }
417 
IsPartitionedCall(const NodeDef & node)418 bool IsPartitionedCall(const NodeDef& node) {
419   return node.op() == "PartitionedCall";
420 }
421 
IsPlaceholder(const NodeDef & node)422 bool IsPlaceholder(const NodeDef& node) {
423   const auto& op = node.op();
424   return op == "Placeholder" || op == "PlaceholderV2" ||
425          op == "PlaceholderWithDefault";
426 }
427 
IsPolygamma(const NodeDef & node)428 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
429 
IsPow(const NodeDef & node)430 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
431 
IsPrint(const NodeDef & node)432 bool IsPrint(const NodeDef& node) {
433   return node.op() == "Print" || node.op() == "PrintV2";
434 }
435 
IsProd(const NodeDef & node)436 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
437 
IsQuantizedMatMul(const NodeDef & node)438 bool IsQuantizedMatMul(const NodeDef& node) {
439   return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
440 }
441 
IsQueue(const NodeDef & node)442 bool IsQueue(const NodeDef& node) {
443   return str_util::EndsWith(node.op(), "QueueV2");
444 }
445 
IsRandomShuffle(const NodeDef & node)446 bool IsRandomShuffle(const NodeDef& node) {
447   return node.op() == "RandomShuffle";
448 }
449 
IsRank(const NodeDef & node)450 bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }
451 
IsReadVariableOp(const NodeDef & node)452 bool IsReadVariableOp(const NodeDef& node) {
453   return node.op() == "ReadVariableOp";
454 }
455 
IsReadVariablesOp(const NodeDef & node)456 bool IsReadVariablesOp(const NodeDef& node) {
457   return node.op() == "_ReadVariablesOp";
458 }
459 
IsReal(const NodeDef & node)460 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
461 
IsRealDiv(const NodeDef & node)462 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
463 
IsReciprocalGrad(const NodeDef & node)464 bool IsReciprocalGrad(const NodeDef& node) {
465   return node.op() == "ReciprocalGrad";
466 }
467 
IsRecv(const NodeDef & node)468 bool IsRecv(const NodeDef& node) {
469   return node.op() == "_Recv" || node.op() == "_HostRecv";
470 }
471 
IsReduction(const NodeDef & node)472 bool IsReduction(const NodeDef& node) {
473   const auto& op = node.op();
474   return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
475          op == "Mean" || op == "Any" || op == "All";
476 }
477 
IsRelu(const NodeDef & node)478 bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }
479 
IsRelu6(const NodeDef & node)480 bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6"; }
481 
IsReluGrad(const NodeDef & node)482 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
483 
IsRelu6Grad(const NodeDef & node)484 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }
485 
IsReshape(const NodeDef & node)486 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }
487 
IsRestore(const NodeDef & node)488 bool IsRestore(const NodeDef& node) {
489   return (node.op() == "Restore" || node.op() == "RestoreV2" ||
490           node.op() == "RestoreSlice");
491 }
492 
IsRetval(const NodeDef & node)493 bool IsRetval(const NodeDef& node) {
494   return node.op() == "_Retval" || node.op() == "_DeviceRetval";
495 }
496 
IsReverse(const NodeDef & node)497 bool IsReverse(const NodeDef& node) {
498   return node.op() == "Reverse" || node.op() == "ReverseV2";
499 }
500 
IsReverseV2(const NodeDef & node)501 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
502 
IsRsqrt(const NodeDef & node)503 bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }
504 
IsRsqrtGrad(const NodeDef & node)505 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
506 
IsSelect(const NodeDef & node)507 bool IsSelect(const NodeDef& node) {
508   return node.op() == "Select" || node.op() == "SelectV2";
509 }
510 
IsSeluGrad(const NodeDef & node)511 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
512 
IsSend(const NodeDef & node)513 bool IsSend(const NodeDef& node) {
514   return node.op() == "_Send" || node.op() == "_HostSend";
515 }
516 
IsShape(const NodeDef & node)517 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
518 
IsShapeN(const NodeDef & node)519 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
520 
IsShuffle(const NodeDef & node)521 bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
522 
IsSigmoid(const NodeDef & node)523 bool IsSigmoid(const NodeDef& node) { return node.op() == "Sigmoid"; }
524 
IsSigmoidGrad(const NodeDef & node)525 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
526 
IsSize(const NodeDef & node)527 bool IsSize(const NodeDef& node) { return node.op() == "Size"; }
528 
IsSlice(const NodeDef & node)529 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
530 
IsSnapshot(const NodeDef & node)531 bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }
532 
IsSoftmax(const NodeDef & node)533 bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }
534 
IsSoftplusGrad(const NodeDef & node)535 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }
536 
IsSoftsignGrad(const NodeDef & node)537 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
538 
IsSplit(const NodeDef & node)539 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }
540 
IsSplitV(const NodeDef & node)541 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
542 
IsSqrt(const NodeDef & node)543 bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }
544 
IsSqrtGrad(const NodeDef & node)545 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
546 
IsSquare(const NodeDef & node)547 bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
548 
IsSquaredDifference(const NodeDef & node)549 bool IsSquaredDifference(const NodeDef& node) {
550   return node.op() == "SquaredDifference";
551 }
552 
IsSqueeze(const NodeDef & node)553 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
554 
IsStackOp(const NodeDef & node)555 bool IsStackOp(const NodeDef& node) {
556   return node.op() == "Stack" || node.op() == "StackV2";
557 }
IsStackCloseOp(const NodeDef & node)558 bool IsStackCloseOp(const NodeDef& node) {
559   return node.op() == "StackClose" || node.op() == "StackCloseV2";
560 }
IsStackPushOp(const NodeDef & node)561 bool IsStackPushOp(const NodeDef& node) {
562   return node.op() == "StackPush" || node.op() == "StackPushV2";
563 }
IsStackPopOp(const NodeDef & node)564 bool IsStackPopOp(const NodeDef& node) {
565   return node.op() == "StackPop" || node.op() == "StackPopV2";
566 }
567 
IsStatefulPartitionedCall(const NodeDef & node)568 bool IsStatefulPartitionedCall(const NodeDef& node) {
569   return node.op() == "StatefulPartitionedCall";
570 }
571 
IsStopGradient(const NodeDef & node)572 bool IsStopGradient(const NodeDef& node) {
573   const auto& op = node.op();
574   return op == "StopGradient" || op == "PreventGradient";
575 }
576 
IsStridedSlice(const NodeDef & node)577 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }
578 
IsStridedSliceGrad(const NodeDef & node)579 bool IsStridedSliceGrad(const NodeDef& node) {
580   return node.op() == "StridedSliceGrad";
581 }
582 
IsStringToHashBucketFast(const NodeDef & node)583 bool IsStringToHashBucketFast(const NodeDef& node) {
584   return node.op() == "StringToHashBucketFast";
585 }
586 
IsSub(const NodeDef & node)587 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
588 
IsSum(const NodeDef & node)589 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
590 
IsSwitch(const NodeDef & node)591 bool IsSwitch(const NodeDef& node) {
592   const auto& op = node.op();
593   return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
594 }
595 
IsSymbolicGradient(const NodeDef & node)596 bool IsSymbolicGradient(const NodeDef& node) {
597   return node.op() == "SymbolicGradient";
598 }
599 
IsTanh(const NodeDef & node)600 bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }
601 
IsTanhGrad(const NodeDef & node)602 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
603 
IsTensorArray(const NodeDef & node)604 bool IsTensorArray(const NodeDef& node) {
605   static const gtl::FlatSet<string>* const kTensorArrayOps =
606       CHECK_NOTNULL((new gtl::FlatSet<string>{
607           "TensorArray",
608           "TensorArrayV2",
609           "TensorArrayV3",
610           "TensorArrayGrad",
611           "TensorArrayGradV2",
612           "TensorArrayGradV3",
613           "TensorArrayGradWithShape",
614           "TensorArrayWrite",
615           "TensorArrayWriteV2",
616           "TensorArrayWriteV3",
617           "TensorArrayRead",
618           "TensorArrayReadV2",
619           "TensorArrayReadV3",
620           "TensorArrayConcat",
621           "TensorArrayConcatV2",
622           "TensorArrayConcatV3",
623           "TensorArraySplit",
624           "TensorArraySplitV2",
625           "TensorArraySplitV3",
626           "TensorArraySize",
627           "TensorArraySizeV2",
628           "TensorArraySizeV3",
629           "TensorArrayClose",
630           "TensorArrayCloseV2",
631           "TensorArrayCloseV3",
632       }));
633   return kTensorArrayOps->count(node.op()) > 0;
634 }
635 
IsTile(const NodeDef & node)636 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
637 
IsTranspose(const NodeDef & node)638 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
639 
IsTruncateDiv(const NodeDef & node)640 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
641 
IsTruncateMod(const NodeDef & node)642 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
643 
IsUnique(const NodeDef & node)644 bool IsUnique(const NodeDef& node) {
645   const auto& op = node.op();
646   return op == "Unique" || op == "UniqueV2";
647 }
648 
IsUnpack(const NodeDef & node)649 bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
650 
IsVariable(const NodeDef & node)651 bool IsVariable(const NodeDef& node) {
652   const auto& op = node.op();
653   return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
654          op == "VarHandleOp" || op == "ReadVariableOp" ||
655          op == "_VarHandlesOp" || op == "_ReadVariablesOp";
656 }
657 
IsWhile(const NodeDef & node)658 bool IsWhile(const NodeDef& node) {
659   const auto& op = node.op();
660   return op == "While" || op == "StatelessWhile";
661 }
662 
IsXdivy(const NodeDef & node)663 bool IsXdivy(const NodeDef& node) { return node.op() == "Xdivy"; }
664 
IsZerosLike(const NodeDef & node)665 bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }
666 
IsZeta(const NodeDef & node)667 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
668 
669 namespace {
GetBoolAttr(const NodeDef & node,const string & name)670 bool GetBoolAttr(const NodeDef& node, const string& name) {
671   return node.attr().count(name) > 0 && node.attr().at(name).b();
672 }
673 }  // namespace
674 
IsPersistent(const NodeDef & node)675 bool IsPersistent(const NodeDef& node) {
676   return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
677 }
678 
HasRefInput(const NodeDef & node)679 bool HasRefInput(const NodeDef& node) {
680   const OpDef* op_def;
681   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
682   if (!status.ok()) {
683     return false;
684   }
685   // Nodes such as Assign or AssignAdd modify one of their inputs.
686   for (const auto& input : op_def->input_arg()) {
687     if (input.is_ref()) {
688       return true;
689     }
690   }
691   return false;
692 }
693 
IsDataset(const NodeDef & node)694 bool IsDataset(const NodeDef& node) {
695   const string& op = node.op();
696   // See `GetNodeClassForOp` in core/graph/graph.cc.
697   return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
698          op == "DatasetToSingleElement" || op == "ReduceDataset";
699 }
700 
IsStateful(const NodeDef node,const OpRegistryInterface * op_registry)701 bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
702   const OpDef* op_def = nullptr;
703   const string& op_name = node.op();
704   Status status = op_registry->LookUpOpDef(op_name, &op_def);
705   if (!status.ok()) {
706     LOG(WARNING) << "Failed to lookup OpDef for " << op_name
707                  << ". Error: " << status.error_message();
708     return false;
709   }
710   return op_def->is_stateful();
711 }
712 
IsStateful(const NodeDef node)713 bool IsStateful(const NodeDef node) {
714   return IsStateful(node, OpRegistry::Global());
715 }
716 
IsFreeOfSideEffect(const NodeDef & node,const OpRegistryInterface * op_registry)717 bool IsFreeOfSideEffect(const NodeDef& node,
718                         const OpRegistryInterface* op_registry) {
719   // Placeholders must be preserved to keep the graph feedable.
720   if (IsPlaceholder(node)) {
721     return false;
722   }
723   const OpDef* op_def = nullptr;
724   const string& op_name = node.op();
725   Status status = op_registry->LookUpOpDef(op_name, &op_def);
726   if (!status.ok()) {
727     return false;
728   }
729   if (op_def->is_stateful()) {
730     return false;
731   }
732   // Nodes such as Assign or AssignAdd modify one of their inputs.
733   for (const auto& input : op_def->input_arg()) {
734     if (input.is_ref()) {
735       return false;
736     }
737   }
738   // Queue ops modify the queue which is a side effect.
739   if (node.op().find("Queue") != string::npos) {
740     return false;
741   }
742   // Sending a tensor via a network is a side effect.
743   if (IsSend(node)) {
744     return false;
745   }
746   return !ModifiesInputsInPlace(node);
747 }
748 
IsFreeOfSideEffect(const NodeDef & node)749 bool IsFreeOfSideEffect(const NodeDef& node) {
750   return IsFreeOfSideEffect(node, OpRegistry::Global());
751 }
752 
ModifiesInputsInPlace(const NodeDef & node)753 bool ModifiesInputsInPlace(const NodeDef& node) {
754   // Some nodes do in-place updates on regular tensor inputs.
755   const string& op_name = node.op();
756 
757   // Ops that modify resource variables effectively modify one of their inputs.
758   if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
759       op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
760       op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
761       op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
762       op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
763     return false;
764   }
765 
766   string lower_op_name = op_name;
767   std::transform(lower_op_name.begin(), lower_op_name.end(),
768                  lower_op_name.begin(), ::tolower);
769   if (absl::StrContains(lower_op_name, "inplace")) {
770     return true;
771   }
772   return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
773 }
774 
ModifiesFrameInfo(const NodeDef & node)775 bool ModifiesFrameInfo(const NodeDef& node) {
776   return IsEnter(node) || IsExit(node) || IsNextIteration(node);
777 }
778 
779 #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                      \
780   bool Is##PROPERTY_CAP(const NodeDef& node) {                             \
781     if (node.op() == "Add") {                                              \
782       /* Workaround for "Add" not being marked is_commutative and */       \
783       /* is_aggregate. (See cl/173915048). */                              \
784       const auto type = GetDataTypeFromAttr(node, "T");                    \
785       return type != DT_INVALID && type != DT_STRING;                      \
786     }                                                                      \
787     const OpDef* op_def = nullptr;                                         \
788     Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
789     return status.ok() && op_def->is_##PROPERTY();                         \
790   }
791 
OPDEF_PROPERTY_HELPER(Aggregate,aggregate)792 OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
793 OPDEF_PROPERTY_HELPER(Commutative, commutative)
794 
795 bool IsInvolution(const NodeDef& node) {
796   static const gtl::FlatSet<string>* const kInvolutionOps =
797       CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
798                                               "Neg", "LogicalNot"}));
799   return kInvolutionOps->count(node.op()) > 0;
800 }
801 
IsValueAndOrderAndShapePreserving(const NodeDef & node)802 bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
803   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
804     return true;
805   }
806   static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
807       CHECK_NOTNULL((new const gtl::FlatSet<string>{
808           "CheckNumerics",
809           "DebugGradientIdentity",
810           "DeepCopy",
811           "Enter",
812           "Exit",
813           "PreventGradient",
814           "Print",
815           "Snapshot",
816           "StopGradient",
817       }));
818   return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
819          IsIdentity(node);
820 }
821 
IsValueAndOrderPreserving(const NodeDef & node)822 bool IsValueAndOrderPreserving(const NodeDef& node) {
823   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
824     return true;
825   }
826   static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
827       CHECK_NOTNULL((new const gtl::FlatSet<string>{
828           "ExpandDims",
829           "Reshape",
830           "Squeeze",
831       }));
832   return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
833          IsValueAndOrderAndShapePreserving(node);
834 }
835 
IsValuePreserving(const NodeDef & node)836 bool IsValuePreserving(const NodeDef& node) {
837   static const gtl::FlatSet<string>* const kValuePreservingOps =
838       CHECK_NOTNULL((new gtl::FlatSet<string>{
839           "InvertPermutation",
840           "Reverse",
841           "ReverseV2",
842           "Roll",
843           "Transpose",
844           "DepthToSpace",
845           "SpaceToDepth",
846           "BatchToSpace",
847           "BatchToSpaceND",
848           "SpaceToBatch",
849           "SpaceToBatchND",
850       }));
851   return IsValueAndOrderPreserving(node) ||
852          kValuePreservingOps->count(node.op()) > 0;
853 }
854 
IsUnaryElementWise(const NodeDef & node)855 bool IsUnaryElementWise(const NodeDef& node) {
856   static const gtl::FlatSet<string>* const kElementWiseOps =
857       CHECK_NOTNULL((new gtl::FlatSet<string>{
858           "Abs",      "Acos",     "Acosh",      "Asin",       "Asinh",
859           "Atan",     "Atanh",    "Ceil",       "ComplexAbs", "Conj",
860           "Cos",      "Cosh",     "Digamma",    "Elu",        "Erf",
861           "Erfc",     "Exp",      "Expm1",      "Floor",      "Inv",
862           "Invert",   "Isinf",    "Isnan",      "Isfinite",   "Lgamma",
863           "Log",      "Log1p",    "LogicalNot", "Neg",        "Reciprocal",
864           "Relu",     "Relu6",    "Rint",       "Round",      "Selu",
865           "Rsqrt",    "Sigmoid",  "Sign",       "Sin",        "SinH",
866           "Softplus", "Softsign", "Sqrt",       "Square",     "Tan",
867           "Tanh",
868       }));
869   return kElementWiseOps->count(node.op()) > 0 ||
870          IsValueAndOrderAndShapePreserving(node);
871 }
872 
HasOpDef(const NodeDef & node)873 bool HasOpDef(const NodeDef& node) {
874   const OpDef* op_def = nullptr;
875   return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
876 }
877 
IsIdempotent(const NodeDef & node)878 bool IsIdempotent(const NodeDef& node) {
879   return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
880          !ModifiesFrameInfo(node);
881 }
882 
NeverForwardsInputs(const NodeDef & node)883 bool NeverForwardsInputs(const NodeDef& node) {
884   static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
885       (new gtl::FlatSet<string>{"ArgMax",
886                                 "ArgMin",
887                                 "AudioSpectrogram",
888                                 "AvgPool",
889                                 "BatchMatMul",
890                                 "BatchMatMulV2",
891                                 "BatchNormWithGlobalNormalization",
892                                 "BatchToSpace",
893                                 "BatchToSpaceND",
894                                 "Bincount",
895                                 "BroadcastArgs",
896                                 "BroadcastGradientArgs",
897                                 "Bucketize",
898                                 "CTCBeamSearchDecoder",
899                                 "CTCGreedyDecoder",
900                                 "CTCLoss",
901                                 "CompareAndBitpack",
902                                 "ComplexAbs",
903                                 "Concat",
904                                 "ConcatOffset",
905                                 "ConcatV2",
906                                 "Conv2D",
907                                 "Copy",
908                                 "CopyHost",
909                                 "Cross",
910                                 "CudnnRNN",
911                                 "CudnnRNNBackprop",
912                                 "CudnnRNNBackpropV2",
913                                 "CudnnRNNBackpropV3",
914                                 "CudnnRNNCanonicalToParams",
915                                 "CudnnRNNCanonicalToParamsV2",
916                                 "CudnnRNNParamsSize",
917                                 "CudnnRNNParamsToCanonical",
918                                 "CudnnRNNParamsToCanonicalV2",
919                                 "CudnnRNNV2",
920                                 "CudnnRNNV3",
921                                 "CumProd",
922                                 "CumSum",
923                                 "DebugNanCount",
924                                 "DebugNumericSummary",
925                                 "DecodeProtoV2",
926                                 "DecodeWav",
927                                 "DeepCopy",
928                                 "DepthToSpace",
929                                 "Dequantize",
930                                 "Diag",
931                                 "DiagPart",
932                                 "EditDistance",
933                                 "Empty",
934                                 "EncodeProtoV2",
935                                 "EncodeWav",
936                                 "ExtractImagePatches",
937                                 "ExtractVolumePatches",
938                                 "Fill",
939                                 "Gather",
940                                 "GatherNd",
941                                 "GatherV2",
942                                 "HistogramFixedWidth",
943                                 "InvertPermutation",
944                                 "IsInf",
945                                 "IsNan",
946                                 "Isfinite",
947                                 "LinSpace",
948                                 "LowerBound",
949                                 "MatMul",
950                                 "MatrixDiag",
951                                 "MatrixDiagPart",
952                                 "MatrixDiagPartV2",
953                                 "MatrixDiagV2",
954                                 "Mfcc",
955                                 "Multinomial",
956                                 "OneHot",
957                                 "Pack",
958                                 "ParameterizedTruncatedNormal",
959                                 "PopulationCount",
960                                 "RandomGamma",
961                                 "RandomPoisson",
962                                 "RandomPoissonV2",
963                                 "RandomStandardNormal",
964                                 "RandomUniform",
965                                 "RandomUniformInt",
966                                 "Range",
967                                 "Rank",
968                                 "RequantizationRange",
969                                 "Requantize",
970                                 "ReverseSequence",
971                                 "Shape",
972                                 "ShapeN",
973                                 "Size",
974                                 "SpaceToBatch",
975                                 "SpaceToBatchND",
976                                 "SpaceToDepth",
977                                 "SparseMatMul",
978                                 "Split",
979                                 "SplitV",
980                                 "TruncatedNormal",
981                                 "Unique",
982                                 "UniqueV2",
983                                 "UniqueWithCounts",
984                                 "UniqueWithCountsV2",
985                                 "Unpack",
986                                 "UnravelIndex",
987                                 "UpperBound",
988                                 "Where"}));
989   const string& op_name = node.op();
990   return kNonForwardingOps->count(op_name) > 0 ||
991          absl::StrContains(op_name, "Segment") ||
992          absl::StartsWith(op_name, "Quantize");
993 }
994 
IsXlaLaunch(const NodeDef & node)995 bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; }
996 
997 }  // namespace grappler
998 }  // end namespace tensorflow
999