1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
16
17 #include <set>
18 #include <string>
19
20 #include "tensorflow/core/framework/op.h"
21 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops_internal.h"
22
23 namespace tflite {
24 namespace flex {
25
// Returns the names of the core TensorFlow ops that the flex delegate is
// allowed to execute.
//
// The set is built on the first call and is heap-allocated without ever being
// deleted, so the returned reference stays valid for the rest of the program.
const std::set<std::string>& GetFlexAllowlist() {
  // LINT.IfChange
  static const auto* const kFlexOps = new std::set<std::string>{
      // go/keep-sorted start
      "Abort",
      "Abs",
      "Add",
      "AddN",
      "AddV2",
      "AdjustContrast",
      "AdjustContrastv2",
      "AdjustHue",
      "AdjustSaturation",
      "All",
      "Angle",
      "Any",
      "ApplyAdaMax",
      "ApplyAdadelta",
      "ApplyAdagrad",
      "ApplyAdagradDA",
      "ApplyAdagradV2",
      "ApplyAdam",
      "ApplyAddSign",
      "ApplyCenteredRMSProp",
      "ApplyFtrl",
      "ApplyFtrlV2",
      "ApplyGradientDescent",
      "ApplyMomentum",
      "ApplyPowerSign",
      "ApplyProximalAdagrad",
      "ApplyProximalGradientDescent",
      "ApplyRMSProp",
      "ApproximateEqual",
      "ArgMax",
      "ArgMin",
      "AsString",
      "Assert",
      "Assign",
      "AssignAdd",
      "AssignAddVariableOp",
      "AssignSub",
      "AssignSubVariableOp",
      "AssignVariableOp",
      "Atan",
      "Atan2",
      "AudioSpectrogram",
      "AvgPool",
      "AvgPool3D",
      "AvgPool3DGrad",
      "AvgPoolGrad",
      "BatchCholesky",
      "BatchDatasetV2",
      "BatchMatMul",
      "BatchMatMulV2",
      "BatchMatrixBandPart",
      "BatchMatrixDeterminant",
      "BatchMatrixDiag",
      "BatchMatrixDiagPart",
      "BatchMatrixInverse",
      "BatchMatrixSetDiag",
      "BatchMatrixTriangularSolve",
      "BatchNormWithGlobalNormalization",
      "BatchNormWithGlobalNormalizationGrad",
      "BatchToSpace",
      "BatchToSpaceND",
      "BiasAdd",
      "BiasAddGrad",
      "BiasAddV1",
      "Bincount",
      "Bitcast",
      "BitwiseAnd",
      "BitwiseOr",
      "BitwiseXor",
      "BroadcastArgs",
      "BroadcastGradientArgs",
      "BroadcastTo",
      "Bucketize",
      "CTCBeamSearchDecoder",
      "CTCGreedyDecoder",
      "Cast",
      "Ceil",
      "CheckNumerics",
      "CheckNumericsV2",
      "Cholesky",
      "CombinedNonMaxSuppression",
      "Complex",
      "ComplexAbs",
      "Concat",
      "ConcatOffset",
      "ConcatV2",
      "Conj",
      "ConjugateTranspose",
      "Const",
      "ControlTrigger",
      "Conv2D",
      "Conv2DBackpropFilter",
      "Conv2DBackpropInput",
      "Conv3D",
      "Conv3DBackpropFilter",
      "Conv3DBackpropFilterV2",
      "Conv3DBackpropInput",
      "Conv3DBackpropInputV2",
      "Cos",
      "Cosh",
      "CropAndResize",
      "CropAndResizeGradBoxes",
      "CropAndResizeGradImage",
      "Cumprod",
      "Cumsum",
      "CumulativeLogsumexp",
      "DataFormatDimMap",
      "DataFormatVecPermute",
      "DebugGradientIdentity",
      "DebugGradientRefIdentity",
      "DecodeAndCropJpeg",
      "DecodeBase64",
      "DecodeBmp",
      "DecodeGif",
      "DecodeImage",
      "DecodeJpeg",
      "DecodePaddedRaw",
      "DecodePng",
      "DecodeRaw",
      "DecodeWav",
      "DeepCopy",
      "DeleteSessionTensor",
      "DenseBincount",
      "DenseToDenseSetOperation",
      "DenseToSparseSetOperation",
      "DepthToSpace",
      "DepthwiseConv2dNative",
      "DepthwiseConv2dNativeBackpropFilter",
      "DepthwiseConv2dNativeBackpropInput",
      "Dequantize",
      "DestroyResourceOp",
      "DestroyTemporaryVariable",
      "Diag",
      "DiagPart",
      "Dilation2D",
      "Dilation2DBackpropFilter",
      "Dilation2DBackpropInput",
      "Div",
      "DivNoNan",
      "DynamicPartition",
      "DynamicStitch",
      "Einsum",
      "Elu",
      "EluGrad",
      "Empty",
      "EmptyTensorList",
      "EmptyTensorMap",
      "EncodeBase64",
      "EncodeJpeg",
      "EncodeJpegVariableQuality",
      "EncodePng",
      "EncodeWav",
      "EnsureShape",
      "Enter",
      "Equal",
      "Erf",
      "Exit",
      "Exp",
      "ExpandDims",
      "ExtractImagePatches",
      "FFT",
      "FFT2D",
      "FFT3D",
      "FIFOQueue",
      "FIFOQueueV2",
      "FakeQuantWithMinMaxArgs",
      "FakeQuantWithMinMaxArgsGradient",
      "FakeQuantWithMinMaxVars",
      "FakeQuantWithMinMaxVarsGradient",
      "FakeQuantWithMinMaxVarsPerChannel",
      "FakeQuantWithMinMaxVarsPerChannelGradient",
      "FakeQueue",
      "Fill",
      "FilterDataset",
      "FinalizeDataset",
      "Fingerprint",
      "FlatMapDataset",
      "Floor",
      "FloorDiv",
      "FloorMod",
      "FusedBatchNorm",
      "FusedBatchNormGrad",
      "FusedBatchNormGradV2",
      "FusedBatchNormGradV3",
      "FusedBatchNormV2",
      "FusedBatchNormV3",
      "FusedPadConv2D",
      "FusedResizeAndPadConv2D",
      "Gather",
      "GatherNd",
      "GatherV2",
      "GetSessionHandle",
      "GetSessionHandleV2",
      "GetSessionTensor",
      "Greater",
      "GreaterEqual",
      "HSVToRGB",
      "HashTable",
      "HashTableV2",
      "HistogramSummary",
      "IFFT",
      "IFFT2D",
      "IFFT3D",
      "IRFFT",
      "IRFFT2D",
      "IRFFT3D",
      "Identity",
      "IdentityN",
      "Imag",
      "ImageProjectiveTransformV2",
      "ImageProjectiveTransformV3",
      "ImmutableConst",
      "InTopK",
      "InTopKV2",
      "InitializeTable",
      "InitializeTableFromDataset",
      "InitializeTableFromTextFile",
      "InitializeTableFromTextFileV2",
      "InitializeTableV2",
      "InplaceAdd",
      "InplaceSub",
      "InplaceUpdate",
      "Inv",
      "InvGrad",
      "Invert",
      "InvertPermutation",
      "IsFinite",
      "IsNan",
      "IsVariableInitialized",
      "LRN",
      "LeakyRelu",
      "LeakyReluGrad",
      "LeftShift",
      "Less",
      "LessEqual",
      "LinSpace",
      "ListDiff",
      "Log",
      "LogMatrixDeterminant",
      "LogSoftmax",
      "LogicalAnd",
      "LogicalNot",
      "LogicalOr",
      "LookupTableExport",
      "LookupTableExportV2",
      "LookupTableFind",
      "LookupTableFindV2",
      "LookupTableImport",
      "LookupTableImportV2",
      "LookupTableInsert",
      "LookupTableInsertV2",
      "LookupTableRemoveV2",
      "LookupTableSize",
      "LookupTableSizeV2",
      "LoopCond",
      "MapDataset",
      "MatMul",
      "MatrixBandPart",
      "MatrixDeterminant",
      "MatrixDiag",
      "MatrixDiagPart",
      "MatrixDiagPartV2",
      "MatrixDiagPartV3",
      "MatrixDiagV2",
      "MatrixDiagV3",
      "MatrixInverse",
      "MatrixSetDiag",
      "MatrixSetDiagV2",
      "MatrixSetDiagV3",
      "MatrixTriangularSolve",
      "Max",
      "MaxPool",
      "MaxPool3D",
      "MaxPool3DGrad",
      "MaxPool3DGradGrad",
      "MaxPoolGrad",
      "MaxPoolGradGrad",
      "MaxPoolGradGradV2",
      "MaxPoolGradV2",
      "MaxPoolGradWithArgmax",
      "MaxPoolV2",
      "MaxPoolWithArgmax",
      "Maximum",
      "Mean",
      "Merge",
      "MergeSummary",
      "MergeV2Checkpoints",
      "Mfcc",
      "Min",
      "Minimum",
      "MirrorPad",
      "MirrorPadGrad",
      "ModelDataset",
      "Mul",
      "MulNoNan",
      "Multinomial",
      "MutableDenseHashTable",
      "MutableDenseHashTableV2",
      "MutableHashTable",
      "MutableHashTableOfTensors",
      "MutableHashTableOfTensorsV2",
      "MutableHashTableV2",
      "Neg",
      "NextIteration",
      "NoOp",
      "NonMaxSuppression",
      "NonMaxSuppressionV2",
      "NonMaxSuppressionV3",
      "NonMaxSuppressionV4",
      "NonMaxSuppressionV5",
      "NonMaxSuppressionWithOverlaps",
      "NotEqual",
      "OneHot",
      "OnesLike",
      "OptimizeDatasetV2",
      "OptionalFromValue",
      "OptionalGetValue",
      "OptionalHasValue",
      "OptionalNone",
      "Pack",
      "Pad",
      "PadV2",
      "PaddingFIFOQueue",
      "PaddingFIFOQueueV2",
      "ParallelConcat",
      "ParallelDynamicStitch",
      "ParseExample",
      "ParseExampleV2",
      "ParseSequenceExample",
      "ParseSequenceExampleV2",
      "ParseSingleExample",
      "ParseSingleSequenceExample",
      "Placeholder",
      "PlaceholderV2",
      "PlaceholderWithDefault",
      "PopulationCount",
      "Pow",
      "PreventGradient",
      "Print",
      "PrintV2",
      "Prod",
      "Qr",
      "QuantizeDownAndShrinkRange",
      "QuantizeV2",
      "QuantizedAdd",
      "QuantizedAvgPool",
      "QuantizedBatchNormWithGlobalNormalization",
      "QuantizedBiasAdd",
      "QuantizedConcat",
      "QuantizedConv2D",
      "QuantizedInstanceNorm",
      "QuantizedMatMul",
      "QuantizedMaxPool",
      "QuantizedMul",
      "QuantizedRelu",
      "QuantizedRelu6",
      "QuantizedReshape",
      "QuantizedResizeBilinear",
      "QueueClose",
      "QueueCloseV2",
      "QueueDequeue",
      "QueueDequeueMany",
      "QueueDequeueManyV2",
      "QueueDequeueUpTo",
      "QueueDequeueUpToV2",
      "QueueDequeueV2",
      "QueueEnqueue",
      "QueueEnqueueMany",
      "QueueEnqueueManyV2",
      "QueueEnqueueV2",
      "QueueIsClosed",
      "QueueIsClosedV2",
      "QueueSize",
      "QueueSizeV2",
      "RFFT",
      "RFFT2D",
      "RFFT3D",
      "RGBToHSV",
      "RaggedBincount",
      "RaggedGather",
      "RaggedRange",
      "RaggedTensorFromVariant",
      "RaggedTensorToSparse",
      "RaggedTensorToTensor",
      "RaggedTensorToVariant",
      "RaggedTensorToVariantGradient",
      "RandomGamma",
      "RandomPoisson",
      "RandomPoissonV2",
      "RandomShuffle",
      "RandomStandardNormal",
      "RandomUniform",
      "RandomUniformInt",
      "Range",
      "Rank",
      "ReadFile",
      "ReadVariableOp",
      "Real",
      "RealDiv",
      "Reciprocal",
      "ReciprocalGrad",
      "Recv",
      "ReduceDataset",
      "ReduceJoin",
      "RefEnter",
      "RefExit",
      "RefIdentity",
      "RefMerge",
      "RefNextIteration",
      "RefSelect",
      "RefSwitch",
      "RegexFullMatch",
      "RegexReplace",
      "Relu",
      "Relu6",
      "Relu6Grad",
      "ReluGrad",
      "RemoteCall",
      "RepeatDataset",
      "RequantizationRange",
      "Requantize",
      "Reshape",
      "ResizeBicubic",
      "ResizeBicubicGrad",
      "ResizeBilinear",
      "ResizeBilinearGrad",
      "ResizeNearestNeighbor",
      "ResizeNearestNeighborGrad",
      "ResourceApplyAdaMax",
      "ResourceApplyAdadelta",
      "ResourceApplyAdagrad",
      "ResourceApplyAdagradDA",
      "ResourceApplyAdagradV2",
      "ResourceApplyAdam",
      "ResourceApplyAdamWithAmsgrad",
      "ResourceApplyAddSign",
      "ResourceApplyCenteredRMSProp",
      "ResourceApplyFtrl",
      "ResourceApplyFtrlV2",
      "ResourceApplyGradientDescent",
      "ResourceApplyKerasMomentum",
      "ResourceApplyMomentum",
      "ResourceApplyPowerSign",
      "ResourceApplyProximalAdagrad",
      "ResourceApplyProximalGradientDescent",
      "ResourceApplyRMSProp",
      "ResourceGather",
      "ResourceGatherNd",
      "ResourceScatterAdd",
      "ResourceScatterDiv",
      "ResourceScatterMax",
      "ResourceScatterMin",
      "ResourceScatterMul",
      "ResourceScatterNdAdd",
      "ResourceScatterNdMax",
      "ResourceScatterNdMin",
      "ResourceScatterNdSub",
      "ResourceScatterNdUpdate",
      "ResourceScatterSub",
      "ResourceScatterUpdate",
      "ResourceSparseApplyAdadelta",
      "ResourceSparseApplyAdagrad",
      "ResourceSparseApplyAdagradDA",
      "ResourceSparseApplyAdagradV2",
      "ResourceSparseApplyCenteredRMSProp",
      "ResourceSparseApplyFtrl",
      "ResourceSparseApplyFtrlV2",
      "ResourceSparseApplyKerasMomentum",
      "ResourceSparseApplyMomentum",
      "ResourceSparseApplyProximalAdagrad",
      "ResourceSparseApplyProximalGradientDescent",
      "ResourceSparseApplyRMSProp",
      "ResourceStridedSliceAssign",
      "Restore",
      "RestoreSlice",
      "RestoreV2",
      "Reverse",
      "ReverseSequence",
      "ReverseV2",
      "RightShift",
      "Roll",
      "Round",
      "Rsqrt",
      "RsqrtGrad",
      "SampleDistortedBoundingBox",
      "SampleDistortedBoundingBoxV2",
      "Save",
      "SaveSlices",
      "SaveV2",
      "ScalarSummary",
      "ScatterNd",
      "ScatterNdAdd",
      "ScatterNdMax",
      "ScatterNdMin",
      "ScatterNdNonAliasingAdd",
      "ScatterNdSub",
      "ScatterNdUpdate",
      "SegmentMax",
      "SegmentMean",
      "SegmentMin",
      "SegmentProd",
      "SegmentSum",
      "Select",
      "SelectV2",
      "Selu",
      "SeluGrad",
      "Send",
      "SerializeTensor",
      "Shape",
      "ShapeN",
      "ShardedFilename",
      "ShardedFilespec",
      "Sigmoid",
      "SigmoidGrad",
      "Sign",
      "Sin",
      "Sinh",
      "Size",
      "Slice",
      "Softmax",
      "SoftmaxCrossEntropyWithLogits",
      "Softplus",
      "SoftplusGrad",
      "Softsign",
      "SoftsignGrad",
      "SpaceToBatch",
      "SpaceToBatchND",
      "SpaceToDepth",
      "SparseAdd",
      "SparseApplyAdadelta",
      "SparseApplyAdagrad",
      "SparseApplyAdagradDA",
      "SparseApplyAdagradV2",
      "SparseApplyCenteredRMSProp",
      "SparseApplyFtrl",
      "SparseApplyFtrlV2",
      "SparseApplyMomentum",
      "SparseApplyProximalAdagrad",
      "SparseApplyProximalGradientDescent",
      "SparseApplyRMSProp",
      "SparseBincount",
      "SparseCross",
      "SparseCrossHashed",
      "SparseCrossV2",
      "SparseFillEmptyRows",
      "SparseFillEmptyRowsGrad",
      "SparseReduceSum",
      "SparseReorder",
      "SparseReshape",
      "SparseSegmentMean",
      "SparseSegmentMeanGrad",
      "SparseSegmentMeanWithNumSegments",
      "SparseSegmentSqrtN",
      "SparseSegmentSqrtNGrad",
      "SparseSegmentSqrtNWithNumSegments",
      "SparseSegmentSum",
      "SparseSegmentSumGrad",
      "SparseSegmentSumWithNumSegments",
      "SparseSlice",
      "SparseSoftmaxCrossEntropyWithLogits",
      "SparseTensorDenseMatMul",
      "SparseToDense",
      "SparseToSparseSetOperation",
      "Split",
      "SplitV",
      "Sqrt",
      "SqrtGrad",
      "Square",
      "SquaredDifference",
      "Squeeze",
      "Stack",
      "StackClose",
      "StackCloseV2",
      "StackPop",
      "StackPopV2",
      "StackPush",
      "StackPushV2",
      "StackV2",
      "StatelessMultinomial",
      "StatelessRandomGammaV2",
      "StatelessRandomGetAlg",
      "StatelessRandomGetKeyCounter",
      "StatelessRandomGetKeyCounterAlg",
      "StatelessRandomNormal",
      "StatelessRandomNormalV2",
      "StatelessRandomPoisson",
      "StatelessRandomUniform",
      "StatelessRandomUniformFullInt",
      "StatelessRandomUniformFullIntV2",
      "StatelessRandomUniformInt",
      "StatelessRandomUniformIntV2",
      "StatelessRandomUniformV2",
      "StatelessSampleDistortedBoundingBox",
      "StatelessTruncatedNormal",
      "StatelessTruncatedNormalV2",
      "StaticRegexFullMatch",
      "StaticRegexReplace",
      "StopGradient",
      "StridedSlice",
      "StridedSliceAssign",
      "StridedSliceGrad",
      "StringFormat",
      "StringJoin",
      "StringLength",
      "StringLower",
      "StringSplit",
      "StringSplitV2",
      "StringStrip",
      "StringToHashBucket",
      "StringToHashBucketFast",
      "StringToHashBucketStrong",
      "StringToNumber",
      "Sub",
      "Substr",
      "Sum",
      "Switch",
      "SymbolicGradient",
      "TakeDataset",
      "TakeWhileDataset",
      "Tan",
      "Tanh",
      "TanhGrad",
      "TemporaryVariable",
      "TensorArray",
      "TensorArrayClose",
      "TensorArrayCloseV2",
      "TensorArrayCloseV3",
      "TensorArrayConcat",
      "TensorArrayConcatV2",
      "TensorArrayConcatV3",
      "TensorArrayGather",
      "TensorArrayGatherV2",
      "TensorArrayGatherV3",
      "TensorArrayGrad",
      "TensorArrayGradV2",
      "TensorArrayGradV3",
      "TensorArrayGradWithShape",
      "TensorArrayPack",
      "TensorArrayRead",
      "TensorArrayReadV2",
      "TensorArrayReadV3",
      "TensorArrayScatter",
      "TensorArrayScatterV2",
      "TensorArrayScatterV3",
      "TensorArraySize",
      "TensorArraySizeV2",
      "TensorArraySizeV3",
      "TensorArraySplit",
      "TensorArraySplitV2",
      "TensorArraySplitV3",
      "TensorArrayUnpack",
      "TensorArrayV2",
      "TensorArrayV3",
      "TensorArrayWrite",
      "TensorArrayWriteV2",
      "TensorArrayWriteV3",
      "TensorListConcat",
      "TensorListConcatLists",
      "TensorListConcatV2",
      "TensorListElementShape",
      "TensorListFromTensor",
      "TensorListGather",
      "TensorListGetItem",
      "TensorListLength",
      "TensorListPopBack",
      "TensorListPushBack",
      "TensorListPushBackBatch",
      "TensorListReserve",
      "TensorListResize",
      "TensorListScatter",
      "TensorListScatterIntoExistingList",
      "TensorListScatterV2",
      "TensorListSetItem",
      "TensorListSplit",
      "TensorListStack",
      "TensorMapErase",
      "TensorMapHasKey",
      "TensorMapInsert",
      "TensorMapLookup",
      "TensorMapSize",
      "TensorMapStackKeys",
      "TensorScatterAdd",
      "TensorScatterMax",
      "TensorScatterMin",
      "TensorScatterSub",
      "TensorScatterUpdate",
      "TensorSliceDataset",
      "TensorStridedSliceUpdate",
      "Tile",
      "TileGrad",
      "Timestamp",
      "TopK",
      "TopKV2",
      "Transpose",
      "TruncateDiv",
      "TruncatedNormal",
      "UnicodeDecode",
      "UnicodeDecodeWithOffsets",
      "UnicodeEncode",
      "UnicodeTranscode",
      "Unique",
      "UniqueV2",
      "UniqueWithCounts",
      "UniqueWithCountsV2",
      "Unpack",
      "UnsortedSegmentJoin",
      "UnsortedSegmentMax",
      "UnsortedSegmentMin",
      "UnsortedSegmentProd",
      "UnsortedSegmentSum",
      "UnwrapDatasetVariant",
      "UpperBound",
      "VarHandleOp",
      "VarIsInitializedOp",
      "Variable",
      "VariableShape",
      "VariableV2",
      "Where",
      "WrapDatasetVariant",
      "WriteFile",
      "Xdivy",
      "Xlog1py",
      "Xlogy",
      "ZerosLike",
      "_Arg",
      "_ArrayToList",
      "_DeviceArg",
      "_DeviceRetval",
      "_FusedConv2D",
      "_HostCast",
      "_HostRecv",
      "_HostSend",
      "_ListToArray",
      "_ParallelConcatStart",
      "_ParallelConcatUpdate",
      "_ReadVariablesOp",
      "_Recv",
      "_Retval",
      "_Send",
      "_SwitchN",
      "_VarHandlesOp",
      // go/keep-sorted end
  };
  // LINT.ThenChange(//tensorflow/lite/g3doc/guide/op_select_allowlist.md)

  return *kFlexOps;
  // This function is long only because it spells out the op list; splitting
  // it up would not improve readability, so silence the function-length lint.
  // NOLINTNEXTLINE
}
781
// Returns the names of TensorFlow Text ops that are candidates for execution
// through the flex delegate. Being listed here is not enough on its own:
// IsAllowedTFTextOpForFlex() additionally requires the op to be registered.
//
// The set is built on the first call and never deleted, so the returned
// reference remains valid for the rest of the program.
const std::set<std::string>& GetTFTextFlexAllowlist() {
  // LINT.IfChange
  static const auto* const kTfTextOps = new std::set<std::string>{
      // go/keep-sorted start
      "CaseFoldUTF8",
      "ConstrainedSequence",
      "MaxSpanningTree",
      "NormalizeUTF8",
      "NormalizeUTF8WithOffsetsMap",
      "RegexSplitWithOffsets",
      "RougeL",
      "SentenceFragments",
      "SentencepieceDetokenizeOp",
      "SentencepieceOp",
      "SentencepieceTokenizeOp",
      "SentencepieceTokenizeWithOffsetsOp",
      "SentencepieceVocabSizeOp",
      "SplitMergeTokenizeWithOffsets",
      "TFText>NgramsStringJoin",
      "TFText>WhitespaceTokenizeWithOffsetsV2",
      "TokenizerFromLogits",
      "UnicodeScriptTokenizeWithOffsets",
      "WhitespaceTokenizeWithOffsets",
      "WordpieceTokenizeWithOffsets",
      // go/keep-sorted end
  };
  // LINT.ThenChange(//tensorflow/lite/g3doc/guide/op_select_allowlist.md)

  return *kTfTextOps;
}
811
812 // Allow the tf.text ops if they are registered in the global op registry.
IsAllowedTFTextOpForFlex(const std::string & op_name)813 bool IsAllowedTFTextOpForFlex(const std::string& op_name) {
814 if (GetTFTextFlexAllowlist().count(op_name) == 0) return false;
815 return tensorflow::OpRegistry::Global()->LookUp(op_name) != nullptr;
816 }
817
// Returns the names of SentencePiece ops that are candidates for execution
// through the flex delegate. IsAllowedSentencePieceOpForFlex() additionally
// requires a listed op to be registered before it is allowed.
//
// The set is built on the first call and never deleted, so the returned
// reference remains valid for the rest of the program.
const std::set<std::string>& GetSentencePieceFlexAllowlist() {
  // LINT.IfChange
  static const auto* const kSentencePieceOps = new std::set<std::string>{
      // go/keep-sorted start
      "SentencepieceDecode",
      "SentencepieceEncodeDense",
      "SentencepieceEncodeSparse",
      "SentencepieceGetPieceSize",
      "SentencepieceIdToPiece",
      "SentencepiecePieceToId",
      // go/keep-sorted end
  };
  // LINT.ThenChange(//tensorflow/lite/g3doc/guide/op_select_allowlist.md)

  return *kSentencePieceOps;
}
833
834 // Allow the sentencepiece ops if they are registered in the global op registry.
IsAllowedSentencePieceOpForFlex(const std::string & op_name)835 bool IsAllowedSentencePieceOpForFlex(const std::string& op_name) {
836 if (GetSentencePieceFlexAllowlist().count(op_name) == 0) return false;
837 return tensorflow::OpRegistry::Global()->LookUp(op_name) != nullptr;
838 }
839
IsAllowlistedFlexOp(const std::string & tensorflow_op_name)840 bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name) {
841 if (GetFlexAllowlist().count(tensorflow_op_name) != 0) return true;
842
843 // Check if the op is an allowlisted tf.text or sentencepiece op.
844 return IsAllowedTFTextOpForFlex(tensorflow_op_name) ||
845 IsAllowedSentencePieceOpForFlex(tensorflow_op_name);
846 }
847
848 } // namespace flex
849 } // namespace tflite
850