1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/versioning/op_signature.h"
16
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
21
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/platform/resource_loader.h"
24 #include "tensorflow/lite/builtin_ops.h"
25 #include "tensorflow/lite/model_builder.h"
26
27 namespace tflite {
28
29 // StubTfLiteContext is a TfLiteContext which has 3 nodes as the followings.
30 // dummyAdd -> target op -> dummyAdd
31 class StubTfLiteContext : public TfLiteContext {
32 public:
StubTfLiteContext(const int builtin_code,const int op_version,const int num_inputs)33 StubTfLiteContext(const int builtin_code, const int op_version,
34 const int num_inputs)
35 : TfLiteContext({0}) {
36 // Stub execution plan
37 exec_plan_ = TfLiteIntArrayCreate(3);
38 for (int i = 0; i < 3; ++i) exec_plan_->data[i] = i;
39
40 int tensor_no = 0;
41 std::memset(nodes_, 0, sizeof(nodes_));
42 std::memset(registrations_, 0, sizeof(registrations_));
43
44 // Node 0, dummyAdd
45 nodes_[0].inputs = TfLiteIntArrayCreate(1);
46 nodes_[0].inputs->data[0] = tensor_no++;
47 nodes_[0].outputs = TfLiteIntArrayCreate(1);
48 nodes_[0].outputs->data[0] = tensor_no;
49 nodes_[0].builtin_data = nullptr;
50
51 // Node 1, target op
52 nodes_[1].inputs = TfLiteIntArrayCreate(num_inputs);
53 for (int i = 0; i < num_inputs; i++) {
54 nodes_[1].inputs->data[i] = tensor_no++;
55 }
56 nodes_[1].outputs = TfLiteIntArrayCreate(1);
57 nodes_[1].outputs->data[0] = tensor_no;
58 nodes_[1].builtin_data = malloc(1024);
59 std::memset(nodes_[1].builtin_data, 0, 1024);
60
61 // Node 2, dummyAdd
62 nodes_[2].inputs = TfLiteIntArrayCreate(1);
63 nodes_[2].inputs->data[0] = tensor_no++;
64 nodes_[2].outputs = TfLiteIntArrayCreate(1);
65 nodes_[2].outputs->data[0] = tensor_no++;
66 nodes_[2].builtin_data = nullptr;
67
68 // Creates tensors of 4d float32
69 tensors_.resize(tensor_no);
70 for (size_t i = 0; i < tensors_.size(); i++) {
71 std::memset(&tensors_[i], 0, sizeof(tensors_[i]));
72 tensors_[i].buffer_handle = kTfLiteNullBufferHandle;
73 tensors_[i].type = kTfLiteFloat32;
74 tensors_[i].dims = TfLiteIntArrayCreate(4);
75 for (int d = 0; d < 4; d++) {
76 tensors_[i].dims->data[d] = 1;
77 }
78 }
79 tensors = tensors_.data();
80 tensors_size = tensors_.size();
81
82 // Creates registrations
83 registrations_[0].builtin_code = kTfLiteBuiltinAdd;
84 registrations_[1].builtin_code = builtin_code;
85 registrations_[1].version = op_version;
86 registrations_[2].builtin_code = kTfLiteBuiltinAdd;
87
88 this->GetExecutionPlan = StubGetExecutionPlan;
89 this->GetNodeAndRegistration = StubGetNodeAndRegistration;
90 }
~StubTfLiteContext()91 ~StubTfLiteContext() {
92 for (auto& node : nodes_) {
93 TfLiteIntArrayFree(node.inputs);
94 TfLiteIntArrayFree(node.outputs);
95 if (node.builtin_data) {
96 free(node.builtin_data);
97 }
98 }
99 for (auto& tensor : tensors_) {
100 TfLiteIntArrayFree(tensor.dims);
101 }
102 TfLiteIntArrayFree(exec_plan_);
103 }
104
exec_plan() const105 TfLiteIntArray* exec_plan() const { return exec_plan_; }
node()106 TfLiteNode* node() { return &nodes_[1]; }
registration()107 TfLiteRegistration* registration() { return ®istrations_[1]; }
node(int node_index)108 TfLiteNode* node(int node_index) { return &nodes_[node_index]; }
registration(int reg_index)109 TfLiteRegistration* registration(int reg_index) {
110 return ®istrations_[reg_index];
111 }
tensor(int tensor_index)112 TfLiteTensor* tensor(int tensor_index) { return &tensors_[tensor_index]; }
113
114 private:
StubGetExecutionPlan(TfLiteContext * context,TfLiteIntArray ** execution_plan)115 static TfLiteStatus StubGetExecutionPlan(TfLiteContext* context,
116 TfLiteIntArray** execution_plan) {
117 StubTfLiteContext* stub = reinterpret_cast<StubTfLiteContext*>(context);
118 *execution_plan = stub->exec_plan();
119 return kTfLiteOk;
120 }
121
StubGetNodeAndRegistration(TfLiteContext * context,int node_index,TfLiteNode ** node,TfLiteRegistration ** registration)122 static TfLiteStatus StubGetNodeAndRegistration(
123 TfLiteContext* context, int node_index, TfLiteNode** node,
124 TfLiteRegistration** registration) {
125 StubTfLiteContext* stub = reinterpret_cast<StubTfLiteContext*>(context);
126 *node = stub->node(node_index);
127 *registration = stub->registration(node_index);
128 return kTfLiteOk;
129 }
130
131 TfLiteIntArray* exec_plan_;
132 TfLiteNode nodes_[3];
133 TfLiteRegistration registrations_[3];
134 std::vector<TfLiteTensor> tensors_;
135 };
136
TEST(GetOpSignature,FlatBufferModel)137 TEST(GetOpSignature, FlatBufferModel) {
138 const std::string& full_path =
139 tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin");
140 auto fb_model = FlatBufferModel::BuildFromFile(full_path.data());
141 ASSERT_TRUE(fb_model);
142 auto model = fb_model->GetModel();
143 auto subgraphs = model->subgraphs();
144 const SubGraph* subgraph = subgraphs->Get(0);
145 const Operator* op1 = subgraph->operators()->Get(0);
146 const OperatorCode* op_code1 =
147 model->operator_codes()->Get(op1->opcode_index());
148 OpSignature op_sig = GetOpSignature(op_code1, op1, subgraph, model);
149 EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
150 EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
151 EXPECT_EQ(op_sig.inputs[0].dims.size(), 4);
152 EXPECT_FALSE(op_sig.inputs[0].is_const);
153 EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic);
154 EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
155 EXPECT_FALSE(op_sig.outputs[0].is_const);
156 EXPECT_EQ(op_sig.outputs[0].dims.size(), 4);
157 EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic);
158 EXPECT_NE(op_sig.builtin_data, nullptr);
159 EXPECT_EQ(op_sig.version, 1);
160 free(op_sig.builtin_data);
161
162 const Operator* op2 = subgraph->operators()->Get(1);
163 const OperatorCode* op_code2 =
164 model->operator_codes()->Get(op2->opcode_index());
165 op_sig = GetOpSignature(op_code2, op2, subgraph, model);
166 EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
167 EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
168 EXPECT_EQ(op_sig.inputs[0].dims.size(), 4);
169 EXPECT_FALSE(op_sig.inputs[0].is_const);
170 EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic);
171 EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
172 EXPECT_FALSE(op_sig.outputs[0].is_const);
173 EXPECT_EQ(op_sig.outputs[0].dims.size(), 4);
174 EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic);
175 EXPECT_NE(op_sig.builtin_data, nullptr);
176 EXPECT_EQ(op_sig.version, 1);
177 free(op_sig.builtin_data);
178
179 const std::string& full_path3 = tensorflow::GetDataDependencyFilepath(
180 "tensorflow/lite/testdata/multi_signatures.bin");
181 auto fb_model3 = FlatBufferModel::BuildFromFile(full_path3.data());
182 ASSERT_TRUE(fb_model3);
183 auto model3 = fb_model3->GetModel();
184 auto subgraphs3 = model3->subgraphs();
185 const SubGraph* subgraph3 = subgraphs3->Get(0);
186 const Operator* op3 = subgraph3->operators()->Get(0);
187 const OperatorCode* op_code3 =
188 model3->operator_codes()->Get(op3->opcode_index());
189 op_sig = GetOpSignature(op_code3, op3, subgraph3, model3);
190 EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
191 EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
192 EXPECT_EQ(op_sig.inputs[0].dims.size(), 1);
193 EXPECT_FALSE(op_sig.inputs[0].is_const);
194 EXPECT_TRUE(op_sig.inputs[0].is_shape_dynamic);
195 EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
196 EXPECT_FALSE(op_sig.outputs[0].is_const);
197 EXPECT_EQ(op_sig.outputs[0].dims.size(), 1);
198 EXPECT_TRUE(op_sig.outputs[0].is_shape_dynamic);
199 EXPECT_NE(op_sig.builtin_data, nullptr);
200 EXPECT_EQ(op_sig.version, 1);
201 free(op_sig.builtin_data);
202 }
203
TEST(GetOpSignature, TfLiteContext) {
  // The stub's target node (node 1) is an ADD (version 1) with four inputs;
  // every stub tensor is float32 with shape [1, 1, 1, 1].
  auto context = std::make_unique<StubTfLiteContext>(kTfLiteBuiltinAdd,
                                                     /*op_version=*/1,
                                                     /*num_inputs=*/4);
  OpSignature op_sig =
      GetOpSignature(context.get(), context->node(), context->registration());
  EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
  // Stop before indexing [0] if the signature has no inputs or outputs.
  ASSERT_FALSE(op_sig.inputs.empty());
  ASSERT_FALSE(op_sig.outputs.empty());
  EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
  EXPECT_EQ(op_sig.inputs[0].dims.size(), 4u);
  EXPECT_FALSE(op_sig.inputs[0].is_const);
  EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic);
  EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
  EXPECT_FALSE(op_sig.outputs[0].is_const);
  EXPECT_EQ(op_sig.outputs[0].dims.size(), 4u);
  EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic);
  EXPECT_NE(op_sig.builtin_data, nullptr);
  EXPECT_EQ(op_sig.version, 1);
  // NOTE(review): builtin_data is deliberately not freed here; presumably it
  // aliases the stub node's buffer, which the stub's destructor frees — confirm.
}
222
223 } // namespace tflite
224