1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/op_types.h"
17
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/utils.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/gtl/flatset.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace tensorflow {
28 namespace grappler {
29
IsAdd(const NodeDef & node)30 bool IsAdd(const NodeDef& node) {
31 if (node.op() == "AddV2") {
32 return true;
33 }
34 if (node.op() == "Add") {
35 DataType type = node.attr().at("T").type();
36 return type != DT_STRING;
37 }
38 return false;
39 }
40
IsAddN(const NodeDef & node)41 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
42
IsAll(const NodeDef & node)43 bool IsAll(const NodeDef& node) { return node.op() == "All"; }
44
IsAngle(const NodeDef & node)45 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
46
IsAny(const NodeDef & node)47 bool IsAny(const NodeDef& node) { return node.op() == "Any"; }
48
IsAnyDiv(const NodeDef & node)49 bool IsAnyDiv(const NodeDef& node) {
50 return node.op() == "RealDiv" || node.op() == "Div" || node.op() == "Xdivy" ||
51 node.op() == "FloorDiv" || node.op() == "TruncateDiv";
52 }
53
IsAnyBatchMatMul(const NodeDef & node)54 bool IsAnyBatchMatMul(const NodeDef& node) {
55 return node.op() == "BatchMatMul" || node.op() == "BatchMatMulV2";
56 }
57
IsAnyMatMul(const NodeDef & node)58 bool IsAnyMatMul(const NodeDef& node) {
59 return node.op() == "MatMul" || node.op() == "SparseMatMul" ||
60 IsAnyBatchMatMul(node) || IsQuantizedMatMul(node);
61 }
62
IsAnyMax(const NodeDef & node)63 bool IsAnyMax(const NodeDef& node) {
64 const auto& op = node.op();
65 return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
66 }
67
IsAnyMaxPool(const NodeDef & node)68 bool IsAnyMaxPool(const NodeDef& node) {
69 const auto& op = node.op();
70 return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
71 op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
72 }
73
IsAnyMin(const NodeDef & node)74 bool IsAnyMin(const NodeDef& node) {
75 const auto& op = node.op();
76 return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
77 }
78
IsAnySparseSegmentReduction(const NodeDef & node)79 bool IsAnySparseSegmentReduction(const NodeDef& node) {
80 const auto& op = node.op();
81 return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" ||
82 op == "SparseSegmentMean" ||
83 op == "SparseSegmentMeanWithNumSegments" ||
84 op == "SparseSegmentSqrtN" ||
85 op == "SparseSegmentSqrtNWithNumSegments";
86 }
87
IsApproximateEqual(const NodeDef & node)88 bool IsApproximateEqual(const NodeDef& node) {
89 return node.op() == "ApproximateEqual";
90 }
91
IsArg(const NodeDef & node)92 bool IsArg(const NodeDef& node) {
93 return node.op() == "_Arg" || node.op() == "_DeviceArg";
94 }
95
IsArgMax(const NodeDef & node)96 bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }
97
IsArgMin(const NodeDef & node)98 bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
99
IsAvgPoolGrad(const NodeDef & node)100 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
101
IsAssign(const NodeDef & node)102 bool IsAssign(const NodeDef& node) {
103 return node.op() == "Assign" || node.op() == "AssignVariableOp";
104 }
105
IsAssert(const NodeDef & node)106 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
107
IsAsString(const NodeDef & node)108 bool IsAsString(const NodeDef& node) { return node.op() == "AsString"; }
109
IsAtan2(const NodeDef & node)110 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
111
IsBetainc(const NodeDef & node)112 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }
113
IsBiasAdd(const NodeDef & node)114 bool IsBiasAdd(const NodeDef& node) {
115 return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
116 }
117
IsBiasAddV2(const NodeDef & node)118 bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }
119
IsBiasAddGrad(const NodeDef & node)120 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
121
IsBitcast(const NodeDef & node)122 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
123
IsBroadcastTo(const NodeDef & node)124 bool IsBroadcastTo(const NodeDef& node) { return node.op() == "BroadcastTo"; }
125
IsCast(const NodeDef & node)126 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
127
IsCastLike(const NodeDef & node)128 bool IsCastLike(const NodeDef& node) {
129 static const gtl::FlatSet<string>* const kCastLikeOps =
130 CHECK_NOTNULL((new gtl::FlatSet<string>{
131 "Angle", "Bucketize", "Cast", "Dequantize", "HistogramFixedWidth",
132 "Imag", "IsFinite", "IsInf", "IsNan", "Quantize",
133 "QuantizeDownAndShrinkRange", "QuantizeV2", "QuantizedInstanceNorm",
134 "QuantizedRelu", "QuantizedRelu6", "QuantizedReluX", "Real",
135 "Requantize"}));
136 return kCastLikeOps->count(node.op()) > 0;
137 }
138
IsCheckNumerics(const NodeDef & node)139 bool IsCheckNumerics(const NodeDef& node) {
140 return node.op() == "CheckNumerics";
141 }
142
IsCollective(const NodeDef & node)143 bool IsCollective(const NodeDef& node) {
144 return node.op() == "CollectiveReduce" ||
145 node.op() == "CollectiveBcastSend" ||
146 node.op() == "CollectiveBcastRecv";
147 }
148
IsComplex(const NodeDef & node)149 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
150
IsComplexAbs(const NodeDef & node)151 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
152
IsConcat(const NodeDef & node)153 bool IsConcat(const NodeDef& node) {
154 return node.op() == "Concat" || node.op() == "ConcatV2";
155 }
156
IsConcatOffset(const NodeDef & node)157 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
158
IsConstant(const NodeDef & node)159 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
160
IsConj(const NodeDef & node)161 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
162
IsConjugateTranspose(const NodeDef & node)163 bool IsConjugateTranspose(const NodeDef& node) {
164 return node.op() == "ConjugateTranspose";
165 }
166
IsControlFlow(const NodeDef & node)167 bool IsControlFlow(const NodeDef& node) {
168 // clang-format off
169 return node.op() == "ControlTrigger" ||
170 node.op() == "Enter" ||
171 node.op() == "Exit" ||
172 node.op() == "LoopCond" ||
173 node.op() == "Merge" ||
174 node.op() == "_XlaMerge" ||
175 node.op() == "NextIteration" ||
176 node.op() == "Switch" ||
177 node.op() == "_SwitchN";
178 // clang-format on
179 }
180
IsConv2D(const NodeDef & node)181 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
182
IsConv2DBackpropFilter(const NodeDef & node)183 bool IsConv2DBackpropFilter(const NodeDef& node) {
184 return node.op() == "Conv2DBackpropFilter";
185 }
186
IsConv2DBackpropInput(const NodeDef & node)187 bool IsConv2DBackpropInput(const NodeDef& node) {
188 return node.op() == "Conv2DBackpropInput";
189 }
190
IsConv3D(const NodeDef & node)191 bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }
192
IsConv3DBackpropFilterV2(const NodeDef & node)193 bool IsConv3DBackpropFilterV2(const NodeDef& node) {
194 return node.op() == "Conv3DBackpropFilterV2";
195 }
196
IsConv3DBackpropInputV2(const NodeDef & node)197 bool IsConv3DBackpropInputV2(const NodeDef& node) {
198 return node.op() == "Conv3DBackpropInputV2";
199 }
200
IsDepthwiseConv2dNative(const NodeDef & node)201 bool IsDepthwiseConv2dNative(const NodeDef& node) {
202 return node.op() == "DepthwiseConv2dNative";
203 }
204
IsDepthwiseConv2dNativeBackpropFilter(const NodeDef & node)205 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
206 return node.op() == "DepthwiseConv2dNativeBackpropFilter";
207 }
208
IsDepthwiseConv2dNativeBackpropInput(const NodeDef & node)209 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
210 return node.op() == "DepthwiseConv2dNativeBackpropInput";
211 }
212
IsDequeueOp(const NodeDef & node)213 bool IsDequeueOp(const NodeDef& node) {
214 const auto& op = node.op();
215 return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
216 op == "QueueDequeueV2" || op == "QueueDequeue" ||
217 op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
218 }
219
IsDiv(const NodeDef & node)220 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
221
IsDivNoNan(const NodeDef & node)222 bool IsDivNoNan(const NodeDef& node) { return node.op() == "DivNoNan"; }
223
224 // Returns true if node represents a unary elementwise function that is
225 // monotonic. If *is_non_decreasing is true, the function is non-decreasing,
226 // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
227 // e.g. inv.
IsElementWiseMonotonic(const NodeDef & node,bool * is_non_decreasing)228 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
229 static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
230 CHECK_NOTNULL((new gtl::FlatSet<string>{
231 "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil",
232 "Elu", "Erf", "Exp", "Expm1", "Floor", "Log",
233 "Log1p", "Relu", "Relu6", "Rint", "Selu", "Sigmoid",
234 "Sign", "Sinh", "Softsign", "Softplus", "Sqrt", "Tanh",
235 }));
236 static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
237 CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
238 if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
239 if (is_non_decreasing) {
240 *is_non_decreasing = true;
241 }
242 return true;
243 } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
244 if (is_non_decreasing) {
245 *is_non_decreasing = false;
246 }
247 return true;
248 }
249 return false;
250 }
251
IsElu(const NodeDef & node)252 bool IsElu(const NodeDef& node) { return node.op() == "Elu"; }
253
IsEluGrad(const NodeDef & node)254 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
255
IsQuantizationEmulation(const NodeDef & node)256 bool IsQuantizationEmulation(const NodeDef& node) {
257 const auto& op = node.op();
258 return absl::StartsWith(op, "QuantizeAndDequantize") ||
259 absl::StartsWith(op, "FakeQuantWithMinMax");
260 }
261
IsEnter(const NodeDef & node)262 bool IsEnter(const NodeDef& node) {
263 const auto& op = node.op();
264 return op == "Enter" || op == "RefEnter";
265 }
266
IsEqual(const NodeDef & node)267 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
268
IsExit(const NodeDef & node)269 bool IsExit(const NodeDef& node) {
270 const auto& op = node.op();
271 return op == "Exit" || op == "RefExit";
272 }
273
IsExp(const NodeDef & node)274 bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
275
IsFakeParam(const NodeDef & node)276 bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }
277
IsFill(const NodeDef & node)278 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
279
IsFloorDiv(const NodeDef & node)280 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
281
IsFloorMod(const NodeDef & node)282 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
283
IsFusedBatchNorm(const NodeDef & node)284 bool IsFusedBatchNorm(const NodeDef& node) {
285 const auto& op = node.op();
286 return op == "FusedBatchNorm" || op == "FusedBatchNormV2" ||
287 op == "FusedBatchNormV3";
288 }
289
IsFusedBatchNormEx(const NodeDef & node)290 bool IsFusedBatchNormEx(const NodeDef& node) {
291 return node.op() == "_FusedBatchNormEx";
292 }
293
IsFusedBatchNormGrad(const NodeDef & node)294 bool IsFusedBatchNormGrad(const NodeDef& node) {
295 const auto& op = node.op();
296 return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2" ||
297 op == "FusedBatchNormGradV3";
298 }
299
IsGather(const NodeDef & node)300 bool IsGather(const NodeDef& node) {
301 const auto& op = node.op();
302 return op == "Gather" || op == "GatherV2" || op == "ResourceGather";
303 }
304
IsGreater(const NodeDef & node)305 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
306
IsGreaterEqual(const NodeDef & node)307 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
308
IsHostConstant(const NodeDef & node)309 bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }
310
IsHistogramSummary(const NodeDef & node)311 bool IsHistogramSummary(const NodeDef& node) {
312 return node.op() == "HistogramSummary";
313 }
314
IsIdentity(const NodeDef & node)315 bool IsIdentity(const NodeDef& node) {
316 const auto& op = node.op();
317 return op == "Identity" || op == "RefIdentity";
318 }
319
IsIdentityN(const NodeDef & node)320 bool IsIdentityN(const NodeDef& node) {
321 const auto& op = node.op();
322 return op == "IdentityN";
323 }
324
IsIdentityNSingleInput(const NodeDef & node)325 bool IsIdentityNSingleInput(const NodeDef& node) {
326 return IsIdentityN(node) && node.attr().count("T") != 0 &&
327 node.attr().at("T").list().type_size() == 1;
328 }
329
IsIf(const NodeDef & node)330 bool IsIf(const NodeDef& node) {
331 const auto& op = node.op();
332 return op == "If" || op == "StatelessIf";
333 }
334
IsIgamma(const NodeDef & node)335 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
336
IsIgammac(const NodeDef & node)337 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
338
IsImag(const NodeDef & node)339 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
340
IsImmutableConst(const NodeDef & node)341 bool IsImmutableConst(const NodeDef& node) {
342 return node.op() == "ImmutableConst";
343 }
344
IsInvGrad(const NodeDef & node)345 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
346
IsLeakyRelu(const NodeDef & node)347 bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }
348
IsLeakyReluGrad(const NodeDef & node)349 bool IsLeakyReluGrad(const NodeDef& node) {
350 return node.op() == "LeakyReluGrad";
351 }
352
IsLess(const NodeDef & node)353 bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
354
IsLessEqual(const NodeDef & node)355 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
356
IsLog(const NodeDef & node)357 bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
358
IsLogicalAnd(const NodeDef & node)359 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
360
IsLogicalNot(const NodeDef & node)361 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
362
IsLogicalOr(const NodeDef & node)363 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
364
IsLoopCond(const NodeDef & node)365 bool IsLoopCond(const NodeDef& node) { return node.op() == "LoopCond"; }
366
IsMatMul(const NodeDef & node)367 bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
368
IsMax(const NodeDef & node)369 bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
370
IsMaximum(const NodeDef & node)371 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
372
IsMaxPoolGrad(const NodeDef & node)373 bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }
374
IsMean(const NodeDef & node)375 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
376
IsMerge(const NodeDef & node)377 bool IsMerge(const NodeDef& node) {
378 const auto& op = node.op();
379 return op == "Merge" || op == "RefMerge" || op == "_XlaMerge";
380 }
381
IsMin(const NodeDef & node)382 bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
383
IsMinimum(const NodeDef & node)384 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
385
IsMirrorPad(const NodeDef & node)386 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }
387
IsMirrorPadGrad(const NodeDef & node)388 bool IsMirrorPadGrad(const NodeDef& node) {
389 return node.op() == "MirrorPadGrad";
390 }
391
IsMod(const NodeDef & node)392 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
393
IsMul(const NodeDef & node)394 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
IsMulNoNan(const NodeDef & node)395 bool IsMulNoNan(const NodeDef& node) { return node.op() == "MulNoNan"; }
IsAnyMul(const NodeDef & node)396 bool IsAnyMul(const NodeDef& node) { return IsMul(node) || IsMulNoNan(node); }
397
IsNeg(const NodeDef & node)398 bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }
399
IsNoOp(const NodeDef & node)400 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
401
IsNotEqual(const NodeDef & node)402 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
403
IsNextIteration(const NodeDef & node)404 bool IsNextIteration(const NodeDef& node) {
405 const auto& op = node.op();
406 return op == "NextIteration" || op == "RefNextIteration";
407 }
408
IsOnesLike(const NodeDef & node)409 bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }
410
IsPack(const NodeDef & node)411 bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
412
IsPad(const NodeDef & node)413 bool IsPad(const NodeDef& node) {
414 const auto& op = node.op();
415 return op == "Pad" || op == "PadV2";
416 }
417
IsPartitionedCall(const NodeDef & node)418 bool IsPartitionedCall(const NodeDef& node) {
419 return node.op() == "PartitionedCall";
420 }
421
IsPlaceholder(const NodeDef & node)422 bool IsPlaceholder(const NodeDef& node) {
423 const auto& op = node.op();
424 return op == "Placeholder" || op == "PlaceholderV2" ||
425 op == "PlaceholderWithDefault";
426 }
427
IsPolygamma(const NodeDef & node)428 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
429
IsPow(const NodeDef & node)430 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
431
IsPrint(const NodeDef & node)432 bool IsPrint(const NodeDef& node) {
433 return node.op() == "Print" || node.op() == "PrintV2";
434 }
435
IsProd(const NodeDef & node)436 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
437
IsQuantizedMatMul(const NodeDef & node)438 bool IsQuantizedMatMul(const NodeDef& node) {
439 return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
440 }
441
IsQueue(const NodeDef & node)442 bool IsQueue(const NodeDef& node) {
443 return str_util::EndsWith(node.op(), "QueueV2");
444 }
445
IsRandomShuffle(const NodeDef & node)446 bool IsRandomShuffle(const NodeDef& node) {
447 return node.op() == "RandomShuffle";
448 }
449
IsRank(const NodeDef & node)450 bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }
451
IsReadVariableOp(const NodeDef & node)452 bool IsReadVariableOp(const NodeDef& node) {
453 return node.op() == "ReadVariableOp";
454 }
455
IsReadVariablesOp(const NodeDef & node)456 bool IsReadVariablesOp(const NodeDef& node) {
457 return node.op() == "_ReadVariablesOp";
458 }
459
IsReal(const NodeDef & node)460 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
461
IsRealDiv(const NodeDef & node)462 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
463
IsReciprocalGrad(const NodeDef & node)464 bool IsReciprocalGrad(const NodeDef& node) {
465 return node.op() == "ReciprocalGrad";
466 }
467
IsRecv(const NodeDef & node)468 bool IsRecv(const NodeDef& node) {
469 return node.op() == "_Recv" || node.op() == "_HostRecv";
470 }
471
IsReduction(const NodeDef & node)472 bool IsReduction(const NodeDef& node) {
473 const auto& op = node.op();
474 return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
475 op == "Mean" || op == "Any" || op == "All";
476 }
477
IsRelu(const NodeDef & node)478 bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }
479
IsRelu6(const NodeDef & node)480 bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6"; }
481
IsReluGrad(const NodeDef & node)482 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
483
IsRelu6Grad(const NodeDef & node)484 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }
485
IsReshape(const NodeDef & node)486 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }
487
IsRestore(const NodeDef & node)488 bool IsRestore(const NodeDef& node) {
489 return (node.op() == "Restore" || node.op() == "RestoreV2" ||
490 node.op() == "RestoreSlice");
491 }
492
IsRetval(const NodeDef & node)493 bool IsRetval(const NodeDef& node) {
494 return node.op() == "_Retval" || node.op() == "_DeviceRetval";
495 }
496
IsReverse(const NodeDef & node)497 bool IsReverse(const NodeDef& node) {
498 return node.op() == "Reverse" || node.op() == "ReverseV2";
499 }
500
IsReverseV2(const NodeDef & node)501 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
502
IsRsqrt(const NodeDef & node)503 bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }
504
IsRsqrtGrad(const NodeDef & node)505 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
506
IsSelect(const NodeDef & node)507 bool IsSelect(const NodeDef& node) {
508 return node.op() == "Select" || node.op() == "SelectV2";
509 }
510
IsSeluGrad(const NodeDef & node)511 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
512
IsSend(const NodeDef & node)513 bool IsSend(const NodeDef& node) {
514 return node.op() == "_Send" || node.op() == "_HostSend";
515 }
516
IsShape(const NodeDef & node)517 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
518
IsShapeN(const NodeDef & node)519 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
520
IsShuffle(const NodeDef & node)521 bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
522
IsSigmoid(const NodeDef & node)523 bool IsSigmoid(const NodeDef& node) { return node.op() == "Sigmoid"; }
524
IsSigmoidGrad(const NodeDef & node)525 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
526
IsSize(const NodeDef & node)527 bool IsSize(const NodeDef& node) { return node.op() == "Size"; }
528
IsSlice(const NodeDef & node)529 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
530
IsSnapshot(const NodeDef & node)531 bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }
532
IsSoftmax(const NodeDef & node)533 bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }
534
IsSoftplusGrad(const NodeDef & node)535 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }
536
IsSoftsignGrad(const NodeDef & node)537 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
538
IsSplit(const NodeDef & node)539 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }
540
IsSplitV(const NodeDef & node)541 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
542
IsSqrt(const NodeDef & node)543 bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }
544
IsSqrtGrad(const NodeDef & node)545 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
546
IsSquare(const NodeDef & node)547 bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
548
IsSquaredDifference(const NodeDef & node)549 bool IsSquaredDifference(const NodeDef& node) {
550 return node.op() == "SquaredDifference";
551 }
552
IsSqueeze(const NodeDef & node)553 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
554
IsStackOp(const NodeDef & node)555 bool IsStackOp(const NodeDef& node) {
556 return node.op() == "Stack" || node.op() == "StackV2";
557 }
IsStackCloseOp(const NodeDef & node)558 bool IsStackCloseOp(const NodeDef& node) {
559 return node.op() == "StackClose" || node.op() == "StackCloseV2";
560 }
IsStackPushOp(const NodeDef & node)561 bool IsStackPushOp(const NodeDef& node) {
562 return node.op() == "StackPush" || node.op() == "StackPushV2";
563 }
IsStackPopOp(const NodeDef & node)564 bool IsStackPopOp(const NodeDef& node) {
565 return node.op() == "StackPop" || node.op() == "StackPopV2";
566 }
567
IsStatefulPartitionedCall(const NodeDef & node)568 bool IsStatefulPartitionedCall(const NodeDef& node) {
569 return node.op() == "StatefulPartitionedCall";
570 }
571
IsStopGradient(const NodeDef & node)572 bool IsStopGradient(const NodeDef& node) {
573 const auto& op = node.op();
574 return op == "StopGradient" || op == "PreventGradient";
575 }
576
IsStridedSlice(const NodeDef & node)577 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }
578
IsStridedSliceGrad(const NodeDef & node)579 bool IsStridedSliceGrad(const NodeDef& node) {
580 return node.op() == "StridedSliceGrad";
581 }
582
IsStringToHashBucketFast(const NodeDef & node)583 bool IsStringToHashBucketFast(const NodeDef& node) {
584 return node.op() == "StringToHashBucketFast";
585 }
586
IsSub(const NodeDef & node)587 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
588
IsSum(const NodeDef & node)589 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
590
IsSwitch(const NodeDef & node)591 bool IsSwitch(const NodeDef& node) {
592 const auto& op = node.op();
593 return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
594 }
595
IsSymbolicGradient(const NodeDef & node)596 bool IsSymbolicGradient(const NodeDef& node) {
597 return node.op() == "SymbolicGradient";
598 }
599
IsTanh(const NodeDef & node)600 bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }
601
IsTanhGrad(const NodeDef & node)602 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
603
IsTensorArray(const NodeDef & node)604 bool IsTensorArray(const NodeDef& node) {
605 static const gtl::FlatSet<string>* const kTensorArrayOps =
606 CHECK_NOTNULL((new gtl::FlatSet<string>{
607 "TensorArray",
608 "TensorArrayV2",
609 "TensorArrayV3",
610 "TensorArrayGrad",
611 "TensorArrayGradV2",
612 "TensorArrayGradV3",
613 "TensorArrayGradWithShape",
614 "TensorArrayWrite",
615 "TensorArrayWriteV2",
616 "TensorArrayWriteV3",
617 "TensorArrayRead",
618 "TensorArrayReadV2",
619 "TensorArrayReadV3",
620 "TensorArrayConcat",
621 "TensorArrayConcatV2",
622 "TensorArrayConcatV3",
623 "TensorArraySplit",
624 "TensorArraySplitV2",
625 "TensorArraySplitV3",
626 "TensorArraySize",
627 "TensorArraySizeV2",
628 "TensorArraySizeV3",
629 "TensorArrayClose",
630 "TensorArrayCloseV2",
631 "TensorArrayCloseV3",
632 }));
633 return kTensorArrayOps->count(node.op()) > 0;
634 }
635
IsTile(const NodeDef & node)636 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
637
IsTranspose(const NodeDef & node)638 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
639
IsTruncateDiv(const NodeDef & node)640 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
641
IsTruncateMod(const NodeDef & node)642 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
643
IsUnique(const NodeDef & node)644 bool IsUnique(const NodeDef& node) {
645 const auto& op = node.op();
646 return op == "Unique" || op == "UniqueV2";
647 }
648
IsUnpack(const NodeDef & node)649 bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
650
IsVariable(const NodeDef & node)651 bool IsVariable(const NodeDef& node) {
652 const auto& op = node.op();
653 return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
654 op == "VarHandleOp" || op == "ReadVariableOp" ||
655 op == "_VarHandlesOp" || op == "_ReadVariablesOp";
656 }
657
IsWhile(const NodeDef & node)658 bool IsWhile(const NodeDef& node) {
659 const auto& op = node.op();
660 return op == "While" || op == "StatelessWhile";
661 }
662
IsXdivy(const NodeDef & node)663 bool IsXdivy(const NodeDef& node) { return node.op() == "Xdivy"; }
664
IsZerosLike(const NodeDef & node)665 bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }
666
IsZeta(const NodeDef & node)667 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
668
669 namespace {
GetBoolAttr(const NodeDef & node,const string & name)670 bool GetBoolAttr(const NodeDef& node, const string& name) {
671 return node.attr().count(name) > 0 && node.attr().at(name).b();
672 }
673 } // namespace
674
IsPersistent(const NodeDef & node)675 bool IsPersistent(const NodeDef& node) {
676 return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
677 }
678
HasRefInput(const NodeDef & node)679 bool HasRefInput(const NodeDef& node) {
680 const OpDef* op_def;
681 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
682 if (!status.ok()) {
683 return false;
684 }
685 // Nodes such as Assign or AssignAdd modify one of their inputs.
686 for (const auto& input : op_def->input_arg()) {
687 if (input.is_ref()) {
688 return true;
689 }
690 }
691 return false;
692 }
693
IsDataset(const NodeDef & node)694 bool IsDataset(const NodeDef& node) {
695 const string& op = node.op();
696 // See `GetNodeClassForOp` in core/graph/graph.cc.
697 return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
698 op == "DatasetToSingleElement" || op == "ReduceDataset";
699 }
700
IsStateful(const NodeDef node,const OpRegistryInterface * op_registry)701 bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
702 const OpDef* op_def = nullptr;
703 const string& op_name = node.op();
704 Status status = op_registry->LookUpOpDef(op_name, &op_def);
705 if (!status.ok()) {
706 LOG(WARNING) << "Failed to lookup OpDef for " << op_name
707 << ". Error: " << status.error_message();
708 return false;
709 }
710 return op_def->is_stateful();
711 }
712
IsStateful(const NodeDef node)713 bool IsStateful(const NodeDef node) {
714 return IsStateful(node, OpRegistry::Global());
715 }
716
IsFreeOfSideEffect(const NodeDef & node,const OpRegistryInterface * op_registry)717 bool IsFreeOfSideEffect(const NodeDef& node,
718 const OpRegistryInterface* op_registry) {
719 // Placeholders must be preserved to keep the graph feedable.
720 if (IsPlaceholder(node)) {
721 return false;
722 }
723 const OpDef* op_def = nullptr;
724 const string& op_name = node.op();
725 Status status = op_registry->LookUpOpDef(op_name, &op_def);
726 if (!status.ok()) {
727 return false;
728 }
729 if (op_def->is_stateful()) {
730 return false;
731 }
732 // Nodes such as Assign or AssignAdd modify one of their inputs.
733 for (const auto& input : op_def->input_arg()) {
734 if (input.is_ref()) {
735 return false;
736 }
737 }
738 // Queue ops modify the queue which is a side effect.
739 if (node.op().find("Queue") != string::npos) {
740 return false;
741 }
742 // Sending a tensor via a network is a side effect.
743 if (IsSend(node)) {
744 return false;
745 }
746 return !ModifiesInputsInPlace(node);
747 }
748
IsFreeOfSideEffect(const NodeDef & node)749 bool IsFreeOfSideEffect(const NodeDef& node) {
750 return IsFreeOfSideEffect(node, OpRegistry::Global());
751 }
752
ModifiesInputsInPlace(const NodeDef & node)753 bool ModifiesInputsInPlace(const NodeDef& node) {
754 // Some nodes do in-place updates on regular tensor inputs.
755 const string& op_name = node.op();
756
757 // Ops that modify resource variables effectively modify one of their inputs.
758 if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
759 op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
760 op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
761 op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
762 op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
763 return false;
764 }
765
766 string lower_op_name = op_name;
767 std::transform(lower_op_name.begin(), lower_op_name.end(),
768 lower_op_name.begin(), ::tolower);
769 if (absl::StrContains(lower_op_name, "inplace")) {
770 return true;
771 }
772 return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
773 }
774
ModifiesFrameInfo(const NodeDef & node)775 bool ModifiesFrameInfo(const NodeDef& node) {
776 return IsEnter(node) || IsExit(node) || IsNextIteration(node);
777 }
778
779 #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \
780 bool Is##PROPERTY_CAP(const NodeDef& node) { \
781 if (node.op() == "Add") { \
782 /* Workaround for "Add" not being marked is_commutative and */ \
783 /* is_aggregate. (See cl/173915048). */ \
784 const auto type = GetDataTypeFromAttr(node, "T"); \
785 return type != DT_INVALID && type != DT_STRING; \
786 } \
787 const OpDef* op_def = nullptr; \
788 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
789 return status.ok() && op_def->is_##PROPERTY(); \
790 }
791
OPDEF_PROPERTY_HELPER(Aggregate,aggregate)792 OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
793 OPDEF_PROPERTY_HELPER(Commutative, commutative)
794
795 bool IsInvolution(const NodeDef& node) {
796 static const gtl::FlatSet<string>* const kInvolutionOps =
797 CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
798 "Neg", "LogicalNot"}));
799 return kInvolutionOps->count(node.op()) > 0;
800 }
801
IsValueAndOrderAndShapePreserving(const NodeDef & node)802 bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
803 if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
804 return true;
805 }
806 static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
807 CHECK_NOTNULL((new const gtl::FlatSet<string>{
808 "CheckNumerics",
809 "DebugGradientIdentity",
810 "DeepCopy",
811 "Enter",
812 "Exit",
813 "PreventGradient",
814 "Print",
815 "Snapshot",
816 "StopGradient",
817 }));
818 return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
819 IsIdentity(node);
820 }
821
IsValueAndOrderPreserving(const NodeDef & node)822 bool IsValueAndOrderPreserving(const NodeDef& node) {
823 if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
824 return true;
825 }
826 static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
827 CHECK_NOTNULL((new const gtl::FlatSet<string>{
828 "ExpandDims",
829 "Reshape",
830 "Squeeze",
831 }));
832 return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
833 IsValueAndOrderAndShapePreserving(node);
834 }
835
IsValuePreserving(const NodeDef & node)836 bool IsValuePreserving(const NodeDef& node) {
837 static const gtl::FlatSet<string>* const kValuePreservingOps =
838 CHECK_NOTNULL((new gtl::FlatSet<string>{
839 "InvertPermutation",
840 "Reverse",
841 "ReverseV2",
842 "Roll",
843 "Transpose",
844 "DepthToSpace",
845 "SpaceToDepth",
846 "BatchToSpace",
847 "BatchToSpaceND",
848 "SpaceToBatch",
849 "SpaceToBatchND",
850 }));
851 return IsValueAndOrderPreserving(node) ||
852 kValuePreservingOps->count(node.op()) > 0;
853 }
854
IsUnaryElementWise(const NodeDef & node)855 bool IsUnaryElementWise(const NodeDef& node) {
856 static const gtl::FlatSet<string>* const kElementWiseOps =
857 CHECK_NOTNULL((new gtl::FlatSet<string>{
858 "Abs", "Acos", "Acosh", "Asin", "Asinh",
859 "Atan", "Atanh", "Ceil", "ComplexAbs", "Conj",
860 "Cos", "Cosh", "Digamma", "Elu", "Erf",
861 "Erfc", "Exp", "Expm1", "Floor", "Inv",
862 "Invert", "Isinf", "Isnan", "Isfinite", "Lgamma",
863 "Log", "Log1p", "LogicalNot", "Neg", "Reciprocal",
864 "Relu", "Relu6", "Rint", "Round", "Selu",
865 "Rsqrt", "Sigmoid", "Sign", "Sin", "SinH",
866 "Softplus", "Softsign", "Sqrt", "Square", "Tan",
867 "Tanh",
868 }));
869 return kElementWiseOps->count(node.op()) > 0 ||
870 IsValueAndOrderAndShapePreserving(node);
871 }
872
HasOpDef(const NodeDef & node)873 bool HasOpDef(const NodeDef& node) {
874 const OpDef* op_def = nullptr;
875 return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
876 }
877
IsIdempotent(const NodeDef & node)878 bool IsIdempotent(const NodeDef& node) {
879 return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
880 !ModifiesFrameInfo(node);
881 }
882
NeverForwardsInputs(const NodeDef & node)883 bool NeverForwardsInputs(const NodeDef& node) {
884 static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
885 (new gtl::FlatSet<string>{"ArgMax",
886 "ArgMin",
887 "AudioSpectrogram",
888 "AvgPool",
889 "BatchMatMul",
890 "BatchMatMulV2",
891 "BatchNormWithGlobalNormalization",
892 "BatchToSpace",
893 "BatchToSpaceND",
894 "Bincount",
895 "BroadcastArgs",
896 "BroadcastGradientArgs",
897 "Bucketize",
898 "CTCBeamSearchDecoder",
899 "CTCGreedyDecoder",
900 "CTCLoss",
901 "CompareAndBitpack",
902 "ComplexAbs",
903 "Concat",
904 "ConcatOffset",
905 "ConcatV2",
906 "Conv2D",
907 "Copy",
908 "CopyHost",
909 "Cross",
910 "CudnnRNN",
911 "CudnnRNNBackprop",
912 "CudnnRNNBackpropV2",
913 "CudnnRNNBackpropV3",
914 "CudnnRNNCanonicalToParams",
915 "CudnnRNNCanonicalToParamsV2",
916 "CudnnRNNParamsSize",
917 "CudnnRNNParamsToCanonical",
918 "CudnnRNNParamsToCanonicalV2",
919 "CudnnRNNV2",
920 "CudnnRNNV3",
921 "CumProd",
922 "CumSum",
923 "DebugNanCount",
924 "DebugNumericSummary",
925 "DecodeProtoV2",
926 "DecodeWav",
927 "DeepCopy",
928 "DepthToSpace",
929 "Dequantize",
930 "Diag",
931 "DiagPart",
932 "EditDistance",
933 "Empty",
934 "EncodeProtoV2",
935 "EncodeWav",
936 "ExtractImagePatches",
937 "ExtractVolumePatches",
938 "Fill",
939 "Gather",
940 "GatherNd",
941 "GatherV2",
942 "HistogramFixedWidth",
943 "InvertPermutation",
944 "IsInf",
945 "IsNan",
946 "Isfinite",
947 "LinSpace",
948 "LowerBound",
949 "MatMul",
950 "MatrixDiag",
951 "MatrixDiagPart",
952 "MatrixDiagPartV2",
953 "MatrixDiagV2",
954 "Mfcc",
955 "Multinomial",
956 "OneHot",
957 "Pack",
958 "ParameterizedTruncatedNormal",
959 "PopulationCount",
960 "RandomGamma",
961 "RandomPoisson",
962 "RandomPoissonV2",
963 "RandomStandardNormal",
964 "RandomUniform",
965 "RandomUniformInt",
966 "Range",
967 "Rank",
968 "RequantizationRange",
969 "Requantize",
970 "ReverseSequence",
971 "Shape",
972 "ShapeN",
973 "Size",
974 "SpaceToBatch",
975 "SpaceToBatchND",
976 "SpaceToDepth",
977 "SparseMatMul",
978 "Split",
979 "SplitV",
980 "TruncatedNormal",
981 "Unique",
982 "UniqueV2",
983 "UniqueWithCounts",
984 "UniqueWithCountsV2",
985 "Unpack",
986 "UnravelIndex",
987 "UpperBound",
988 "Where"}));
989 const string& op_name = node.op();
990 return kNonForwardingOps->count(op_name) > 0 ||
991 absl::StrContains(op_name, "Segment") ||
992 absl::StartsWith(op_name, "Quantize");
993 }
994
IsXlaLaunch(const NodeDef & node)995 bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; }
996
997 } // namespace grappler
998 } // end namespace tensorflow
999