xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/versioning/op_signature_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/versioning/op_signature.h"
16 
17 #include <cstring>
18 #include <memory>
19 #include <string>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 #include "tensorflow/core/platform/resource_loader.h"
24 #include "tensorflow/lite/builtin_ops.h"
25 #include "tensorflow/lite/model_builder.h"
26 
27 namespace tflite {
28 
29 // StubTfLiteContext is a TfLiteContext which has 3 nodes as the followings.
30 // dummyAdd -> target op -> dummyAdd
31 class StubTfLiteContext : public TfLiteContext {
32  public:
StubTfLiteContext(const int builtin_code,const int op_version,const int num_inputs)33   StubTfLiteContext(const int builtin_code, const int op_version,
34                     const int num_inputs)
35       : TfLiteContext({0}) {
36     // Stub execution plan
37     exec_plan_ = TfLiteIntArrayCreate(3);
38     for (int i = 0; i < 3; ++i) exec_plan_->data[i] = i;
39 
40     int tensor_no = 0;
41     std::memset(nodes_, 0, sizeof(nodes_));
42     std::memset(registrations_, 0, sizeof(registrations_));
43 
44     // Node 0, dummyAdd
45     nodes_[0].inputs = TfLiteIntArrayCreate(1);
46     nodes_[0].inputs->data[0] = tensor_no++;
47     nodes_[0].outputs = TfLiteIntArrayCreate(1);
48     nodes_[0].outputs->data[0] = tensor_no;
49     nodes_[0].builtin_data = nullptr;
50 
51     // Node 1, target op
52     nodes_[1].inputs = TfLiteIntArrayCreate(num_inputs);
53     for (int i = 0; i < num_inputs; i++) {
54       nodes_[1].inputs->data[i] = tensor_no++;
55     }
56     nodes_[1].outputs = TfLiteIntArrayCreate(1);
57     nodes_[1].outputs->data[0] = tensor_no;
58     nodes_[1].builtin_data = malloc(1024);
59     std::memset(nodes_[1].builtin_data, 0, 1024);
60 
61     // Node 2, dummyAdd
62     nodes_[2].inputs = TfLiteIntArrayCreate(1);
63     nodes_[2].inputs->data[0] = tensor_no++;
64     nodes_[2].outputs = TfLiteIntArrayCreate(1);
65     nodes_[2].outputs->data[0] = tensor_no++;
66     nodes_[2].builtin_data = nullptr;
67 
68     // Creates tensors of 4d float32
69     tensors_.resize(tensor_no);
70     for (size_t i = 0; i < tensors_.size(); i++) {
71       std::memset(&tensors_[i], 0, sizeof(tensors_[i]));
72       tensors_[i].buffer_handle = kTfLiteNullBufferHandle;
73       tensors_[i].type = kTfLiteFloat32;
74       tensors_[i].dims = TfLiteIntArrayCreate(4);
75       for (int d = 0; d < 4; d++) {
76         tensors_[i].dims->data[d] = 1;
77       }
78     }
79     tensors = tensors_.data();
80     tensors_size = tensors_.size();
81 
82     // Creates registrations
83     registrations_[0].builtin_code = kTfLiteBuiltinAdd;
84     registrations_[1].builtin_code = builtin_code;
85     registrations_[1].version = op_version;
86     registrations_[2].builtin_code = kTfLiteBuiltinAdd;
87 
88     this->GetExecutionPlan = StubGetExecutionPlan;
89     this->GetNodeAndRegistration = StubGetNodeAndRegistration;
90   }
~StubTfLiteContext()91   ~StubTfLiteContext() {
92     for (auto& node : nodes_) {
93       TfLiteIntArrayFree(node.inputs);
94       TfLiteIntArrayFree(node.outputs);
95       if (node.builtin_data) {
96         free(node.builtin_data);
97       }
98     }
99     for (auto& tensor : tensors_) {
100       TfLiteIntArrayFree(tensor.dims);
101     }
102     TfLiteIntArrayFree(exec_plan_);
103   }
104 
exec_plan() const105   TfLiteIntArray* exec_plan() const { return exec_plan_; }
node()106   TfLiteNode* node() { return &nodes_[1]; }
registration()107   TfLiteRegistration* registration() { return &registrations_[1]; }
node(int node_index)108   TfLiteNode* node(int node_index) { return &nodes_[node_index]; }
registration(int reg_index)109   TfLiteRegistration* registration(int reg_index) {
110     return &registrations_[reg_index];
111   }
tensor(int tensor_index)112   TfLiteTensor* tensor(int tensor_index) { return &tensors_[tensor_index]; }
113 
114  private:
StubGetExecutionPlan(TfLiteContext * context,TfLiteIntArray ** execution_plan)115   static TfLiteStatus StubGetExecutionPlan(TfLiteContext* context,
116                                            TfLiteIntArray** execution_plan) {
117     StubTfLiteContext* stub = reinterpret_cast<StubTfLiteContext*>(context);
118     *execution_plan = stub->exec_plan();
119     return kTfLiteOk;
120   }
121 
StubGetNodeAndRegistration(TfLiteContext * context,int node_index,TfLiteNode ** node,TfLiteRegistration ** registration)122   static TfLiteStatus StubGetNodeAndRegistration(
123       TfLiteContext* context, int node_index, TfLiteNode** node,
124       TfLiteRegistration** registration) {
125     StubTfLiteContext* stub = reinterpret_cast<StubTfLiteContext*>(context);
126     *node = stub->node(node_index);
127     *registration = stub->registration(node_index);
128     return kTfLiteOk;
129   }
130 
131   TfLiteIntArray* exec_plan_;
132   TfLiteNode nodes_[3];
133   TfLiteRegistration registrations_[3];
134   std::vector<TfLiteTensor> tensors_;
135 };
136 
TEST(GetOpSignature,FlatBufferModel)137 TEST(GetOpSignature, FlatBufferModel) {
138   const std::string& full_path =
139       tensorflow::GetDataDependencyFilepath("tensorflow/lite/testdata/add.bin");
140   auto fb_model = FlatBufferModel::BuildFromFile(full_path.data());
141   ASSERT_TRUE(fb_model);
142   auto model = fb_model->GetModel();
143   auto subgraphs = model->subgraphs();
144   const SubGraph* subgraph = subgraphs->Get(0);
145   const Operator* op1 = subgraph->operators()->Get(0);
146   const OperatorCode* op_code1 =
147       model->operator_codes()->Get(op1->opcode_index());
148   OpSignature op_sig = GetOpSignature(op_code1, op1, subgraph, model);
149   EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
150   EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
151   EXPECT_EQ(op_sig.inputs[0].dims.size(), 4);
152   EXPECT_FALSE(op_sig.inputs[0].is_const);
153   EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic);
154   EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
155   EXPECT_FALSE(op_sig.outputs[0].is_const);
156   EXPECT_EQ(op_sig.outputs[0].dims.size(), 4);
157   EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic);
158   EXPECT_NE(op_sig.builtin_data, nullptr);
159   EXPECT_EQ(op_sig.version, 1);
160   free(op_sig.builtin_data);
161 
162   const Operator* op2 = subgraph->operators()->Get(1);
163   const OperatorCode* op_code2 =
164       model->operator_codes()->Get(op2->opcode_index());
165   op_sig = GetOpSignature(op_code2, op2, subgraph, model);
166   EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
167   EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
168   EXPECT_EQ(op_sig.inputs[0].dims.size(), 4);
169   EXPECT_FALSE(op_sig.inputs[0].is_const);
170   EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic);
171   EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
172   EXPECT_FALSE(op_sig.outputs[0].is_const);
173   EXPECT_EQ(op_sig.outputs[0].dims.size(), 4);
174   EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic);
175   EXPECT_NE(op_sig.builtin_data, nullptr);
176   EXPECT_EQ(op_sig.version, 1);
177   free(op_sig.builtin_data);
178 
179   const std::string& full_path3 = tensorflow::GetDataDependencyFilepath(
180       "tensorflow/lite/testdata/multi_signatures.bin");
181   auto fb_model3 = FlatBufferModel::BuildFromFile(full_path3.data());
182   ASSERT_TRUE(fb_model3);
183   auto model3 = fb_model3->GetModel();
184   auto subgraphs3 = model3->subgraphs();
185   const SubGraph* subgraph3 = subgraphs3->Get(0);
186   const Operator* op3 = subgraph3->operators()->Get(0);
187   const OperatorCode* op_code3 =
188       model3->operator_codes()->Get(op3->opcode_index());
189   op_sig = GetOpSignature(op_code3, op3, subgraph3, model3);
190   EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
191   EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
192   EXPECT_EQ(op_sig.inputs[0].dims.size(), 1);
193   EXPECT_FALSE(op_sig.inputs[0].is_const);
194   EXPECT_TRUE(op_sig.inputs[0].is_shape_dynamic);
195   EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
196   EXPECT_FALSE(op_sig.outputs[0].is_const);
197   EXPECT_EQ(op_sig.outputs[0].dims.size(), 1);
198   EXPECT_TRUE(op_sig.outputs[0].is_shape_dynamic);
199   EXPECT_NE(op_sig.builtin_data, nullptr);
200   EXPECT_EQ(op_sig.version, 1);
201   free(op_sig.builtin_data);
202 }
203 
// Builds an OpSignature from a live (stub) TfLiteContext for an ADD node with
// four rank-4 float32 inputs.
TEST(GetOpSignature, TfLiteContext) {
  auto context = std::make_unique<StubTfLiteContext>(kTfLiteBuiltinAdd,
                                                     /*op_version=*/1,
                                                     /*num_inputs=*/4);
  OpSignature op_sig =
      GetOpSignature(context.get(), context->node(), context->registration());
  EXPECT_EQ(op_sig.op, BuiltinOperator_ADD);
  // Guard the [0] accesses below; indexing an empty vector is undefined
  // behavior. The stub context is released by its unique_ptr either way.
  ASSERT_FALSE(op_sig.inputs.empty());
  ASSERT_FALSE(op_sig.outputs.empty());
  EXPECT_EQ(op_sig.inputs[0].type, kTfLiteFloat32);
  EXPECT_EQ(op_sig.inputs[0].dims.size(), 4u);
  EXPECT_FALSE(op_sig.inputs[0].is_const);
  EXPECT_FALSE(op_sig.inputs[0].is_shape_dynamic);
  EXPECT_EQ(op_sig.outputs[0].type, kTfLiteFloat32);
  EXPECT_FALSE(op_sig.outputs[0].is_const);
  EXPECT_EQ(op_sig.outputs[0].dims.size(), 4u);
  EXPECT_FALSE(op_sig.outputs[0].is_shape_dynamic);
  EXPECT_NE(op_sig.builtin_data, nullptr);
  EXPECT_EQ(op_sig.version, 1);
  // Unlike the flatbuffer test, builtin_data is not freed here.
  // NOTE(review): this assumes the TfLiteContext overload of GetOpSignature
  // borrows node->builtin_data (owned and freed by the stub) rather than
  // allocating a copy — confirm against op_signature.cc.
}
222 
223 }  // namespace tflite
224