xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/versioning/op_signature.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_
16 #define TENSORFLOW_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_
17 
#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/schema/schema_generated.h"
23 
24 namespace tflite {
25 
26 // OpSignature contains operator parameters for version functions.
27 typedef struct {
28   TfLiteType type;
29   std::vector<int32_t> dims;
30   bool is_const;
31   bool is_shape_dynamic;
32 } OpSignatureTensorSpec;
33 
34 typedef struct {
35   BuiltinOperator op;
36   std::vector<OpSignatureTensorSpec> inputs;
37   std::vector<OpSignatureTensorSpec> outputs;
38   void* builtin_data;
39   int version;
40   const void* custom_initial_data;
41   std::string custom_name;
42   union {
43     struct {
44       bool is_per_channel_quantized;
45       bool is_grouped_convolution;
46     } conv_2d;
47     struct {
48       bool is_per_channel_quantized;
49     } depthwise_conv_2d;
50     struct {
51       // TODO(b/156530611): Make this global when more ops support sparse
52       // computation.
53       bool sparse_weight;
54     } fully_connected;
55     struct {
56       float input1_scale;
57       float input2_scale;
58       float output_scale;
59     } mul;
60     struct {
61       int32_t num_dims;
62     } strided_slice;
63     struct {
64       bool input_quantized;
65     } abs;
66     struct {
67       bool is_per_channel_quantized;
68     } dequantize;
69     struct {
70       bool is_per_channel_quantized;
71     } quantize;
72   } ext_options;
73 } OpSignature;
74 
75 // Generate OpSignature with the given OperatorCode, Operator and Tensors (from
76 // SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and
77 // mostly input and output tensor types are enough to figure out op version.
78 // But some ops (DEPTHWISE_CONV_2D,  FULLY_CONNECTED, ...) require to pass their
79 // options to decide op version.
80 //
81 // WARNING: The caller is responsible to free the allocated
82 // OpSignature.builtin_data memory.
83 OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
84                            const SubGraph* subgraph, const Model* model);
85 
86 // Generate OpSignature with the given TfLiteContext, TfLiteNode and
87 // TfLiteRegistration.
88 // The function can be used by a compatibility checker of a delegate such as
89 // TFLiteOperationParser::IsSupported() in the GPU delegate.
90 OpSignature GetOpSignature(const TfLiteContext* context, const TfLiteNode* node,
91                            const TfLiteRegistration* registration);
92 }  // namespace tflite
93 #endif  // TENSORFLOW_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_
94