xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/schema/upgrade_schema_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for upgrading the TensorFlow Lite schema."""
16
import json
import os
import tempfile

from tensorflow.lite.schema import upgrade_schema as upgrade_schema_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test as test_lib
22
# Minimal version-1 model: no operator codes and no subgraphs.
EMPTY_TEST_SCHEMA_V1 = dict(
    version=1,
    operator_codes=[],
    subgraphs=[],
)
28
# Minimal version-3 model. Like every V3 fixture in this file, it carries a
# `buffers` table whose first entry is an empty buffer.
EMPTY_TEST_SCHEMA_V3 = dict(
    version=3,
    operator_codes=[],
    subgraphs=[],
    buffers=[dict(data=[])],
)
37
# Version-0 layout: tensors, inputs, outputs and operators sit at the top
# level instead of inside a subgraph. Currently referenced only by the
# commented-out V0 upgrade test in TestSchemaUpgrade.
TEST_SCHEMA_V0 = dict(
    operator_codes=[],
    tensors=[],
    inputs=[],
    outputs=[],
    operators=[],
    version=0,
)
46
# Version-3 counterpart of TEST_SCHEMA_V0: the same (empty) tensors, inputs,
# outputs and operators, now wrapped in a single subgraph, plus the empty
# buffer at index 0.
TEST_SCHEMA_V3 = dict(
    operator_codes=[],
    buffers=[dict(data=[])],
    subgraphs=[dict(tensors=[], inputs=[], outputs=[], operators=[])],
    version=3,
)
61
# Version-1 model that uses the old (pre-upgrade) names for op codes and for
# builtin options types. FULL_TEST_SCHEMA_V3 lists the same entries, in the
# same order, under their upgraded names; "ADD" appears unchanged in both.
FULL_TEST_SCHEMA_V1 = {
    "version": 1,
    "operator_codes": [
        {"builtin_code": code} for code in (
            "CONVOLUTION",
            "DEPTHWISE_CONVOLUTION",
            "AVERAGE_POOL",
            "MAX_POOL",
            "L2_POOL",
            "SIGMOID",
            "L2NORM",
            "LOCAL_RESPONSE_NORM",
            "ADD",
            "Basic_RNN",
        )
    ],
    "subgraphs": [{
        "operators": [
            {"builtin_options_type": options_type} for options_type in (
                "PoolOptions",
                "DepthwiseConvolutionOptions",
                "ConvolutionOptions",
                "LocalResponseNormOptions",
                "BasicRNNOptions",
            )
        ],
    }],
    "description": "",
}
119
# Expected upgrade of FULL_TEST_SCHEMA_V1: each op code and builtin options
# type renamed to its current name (same order), plus the empty buffer at
# index 0 that every V3 fixture in this file carries.
FULL_TEST_SCHEMA_V3 = {
    "version": 3,
    "operator_codes": [
        {"builtin_code": code} for code in (
            "CONV_2D",
            "DEPTHWISE_CONV_2D",
            "AVERAGE_POOL_2D",
            "MAX_POOL_2D",
            "L2_POOL_2D",
            "LOGISTIC",
            "L2_NORMALIZATION",
            "LOCAL_RESPONSE_NORMALIZATION",
            "ADD",
            "RNN",
        )
    ],
    "subgraphs": [{
        "operators": [
            {"builtin_options_type": options_type} for options_type in (
                "Pool2DOptions",
                "DepthwiseConv2DOptions",
                "Conv2DOptions",
                "LocalResponseNormalizationOptions",
                "RNNOptions",
            )
        ],
    }],
    "description": "",
    "buffers": [{"data": []}],
}
180
# Version-2 model whose tensors store their data inline (`data_buffer`),
# including one tensor with an empty inline buffer.
BUFFER_TEST_V2 = dict(
    operator_codes=[],
    buffers=[],
    subgraphs=[
        dict(
            tensors=[
                dict(data_buffer=[1, 2, 3, 4]),
                dict(data_buffer=[1, 2, 3, 4, 5, 6, 7, 8]),
                dict(data_buffer=[]),
            ],
            inputs=[],
            outputs=[],
            operators=[],
        ),
    ],
    version=2,
)
203
# Expected upgrade of BUFFER_TEST_V2: inline tensor data moves into the
# top-level `buffers` table and each tensor refers to it by index. The tensor
# whose inline data was empty points at the shared empty buffer at index 0.
BUFFER_TEST_V3 = dict(
    operator_codes=[],
    subgraphs=[
        dict(
            tensors=[dict(buffer=1), dict(buffer=2), dict(buffer=0)],
            inputs=[],
            outputs=[],
            operators=[],
        ),
    ],
    buffers=[
        dict(data=[]),
        dict(data=[1, 2, 3, 4]),
        dict(data=[1, 2, 3, 4, 5, 6, 7, 8]),
    ],
    version=3,
)
236
237
def JsonDumpAndFlush(data, fp):
  """Serialize `data` as JSON into the open file `fp`, then flush `fp`.

  Flushing matters here because the tests in this file pass `fp.name` to the
  converter, which reads the file by path; the JSON text must reach the OS
  before that happens.

  Args:
    data: A JSON-serializable object (in this file, always a dictionary).
    fp: A writable, text-mode file-like object.
  """
  encoded = json.dumps(data)
  fp.write(encoded)
  fp.flush()
247
248
class TestSchemaUpgrade(test_util.TensorFlowTestCase):
  """Tests for `upgrade_schema_lib.Converter`."""

  def testNonExistentFile(self):
    """Converting a path that does not exist raises an I/O error."""
    converter = upgrade_schema_lib.Converter()
    # tempfile.mkstemp() would *create* the file (and leak its descriptor), so
    # the "non-existent" path would in fact exist. Instead, name a file inside
    # a fresh temporary directory; that path is guaranteed not to exist.
    with tempfile.TemporaryDirectory() as temp_dir:
      non_existent = os.path.join(temp_dir, "non_existent.json")
      with self.assertRaisesRegex(IOError, "No such file or directory"):
        converter.Convert(non_existent, non_existent)

  def testInvalidExtension(self):
    """Unsupported input or output file extensions raise ValueError."""
    converter = upgrade_schema_lib.Converter()
    # NamedTemporaryFile keeps an existing (empty) ".foo" file around for the
    # duration of the test and closes/deletes it afterwards, unlike
    # tempfile.mkstemp(), whose descriptor and file were leaked.
    with tempfile.NamedTemporaryFile(suffix=".foo") as invalid_file:
      invalid_extension = invalid_file.name
      with self.assertRaisesRegex(ValueError, "Invalid extension on input"):
        converter.Convert(invalid_extension, invalid_extension)
      with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json:
        JsonDumpAndFlush(EMPTY_TEST_SCHEMA_V1, in_json)
        with self.assertRaisesRegex(ValueError, "Invalid extension on output"):
          converter.Convert(in_json.name, invalid_extension)

  def CheckConversion(self, data_old, data_expected):
    """Upgrades `data_old` and checks the result against `data_expected`.

    Both output paths are exercised: the JSON output is compared with
    `data_expected`, and the binary output is checked by converting
    JSON -> .tflite -> .bin and requiring the two binaries to be identical.

    Args:
      data_old: TFLite model as a dictionary (arbitrary version).
      data_expected: TFLite model as a dictionary (upgraded).
    """
    converter = upgrade_schema_lib.Converter()
    with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as in_json, \
        tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as out_json, \
        tempfile.NamedTemporaryFile(suffix=".bin", mode="w+b") as out_bin, \
        tempfile.NamedTemporaryFile(
            suffix=".tflite", mode="w+b") as out_tflite:
      JsonDumpAndFlush(data_old, in_json)
      # Test JSON output.
      converter.Convert(in_json.name, out_json.name)
      # Test binary output: convert to .tflite, then .tflite -> .bin, and
      # check that the two binaries are byte-identical.
      converter.Convert(in_json.name, out_tflite.name)
      converter.Convert(out_tflite.name, out_bin.name)
      # Use `with` so the read handles are closed instead of leaked.
      with open(out_bin.name, "rb") as bin_file, \
          open(out_tflite.name, "rb") as tflite_file:
        self.assertEqual(bin_file.read(), tflite_file.read())
      # Test that conversion actually produced the expected new JSON. Convert()
      # only receives the path, so read the result back by path as well.
      with open(out_json.name) as json_file:
        converted_schema = json.load(json_file)
      self.assertEqual(converted_schema, data_expected)

  def testAlreadyUpgraded(self):
    """A file already at version 3 should stay at version 3."""
    self.CheckConversion(EMPTY_TEST_SCHEMA_V3, EMPTY_TEST_SCHEMA_V3)
    self.CheckConversion(TEST_SCHEMA_V3, TEST_SCHEMA_V3)
    self.CheckConversion(BUFFER_TEST_V3, BUFFER_TEST_V3)

  # Disabled while incorrectly versioned structures are still around.
  # def testV0Upgrade_IntroducesSubgraphs(self):
  #   """V0 did not have subgraphs; check to make sure they get introduced."""
  #   self.CheckConversion(TEST_SCHEMA_V0, TEST_SCHEMA_V3)

  def testV1Upgrade_RenameOps(self):
    """V1 had many different names for ops; check to make sure they rename."""
    self.CheckConversion(EMPTY_TEST_SCHEMA_V1, EMPTY_TEST_SCHEMA_V3)
    self.CheckConversion(FULL_TEST_SCHEMA_V1, FULL_TEST_SCHEMA_V3)

  def testV2Upgrade_CreateBuffers(self):
    """V2 did not have buffers; check to make sure they are created."""
    self.CheckConversion(BUFFER_TEST_V2, BUFFER_TEST_V3)
315
316
if __name__ == "__main__":
  # Entry point when run as a script: hand control to TensorFlow's test
  # runner (`tensorflow.python.platform.test.main`).
  test_lib.main()
319