xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/proto/proto_op_test_base.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# =============================================================================
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
16"""Test case base for testing proto operations."""
17
# Standard-library imports.
19import ctypes as ct
20import os
21
22from tensorflow.core.framework import types_pb2
23from tensorflow.python.kernel_tests.proto import test_example_pb2
24from tensorflow.python.platform import test
25
26
class ProtoOpTestBase(test.TestCase):
  """Base class for testing proto decoding and encoding ops.

  Each static `*_test_case` method builds one `test_example_pb2.TestCase`
  scenario by filling four repeated fields:

    * `values`: the input `TestValue` messages.
    * `shapes`: the batch shape; in every builder here its product equals the
      number of `values`.
    * `sizes`: one entry per (value, field) pair; presumably the number of
      entries that field holds in that value -- confirm against
      test_example.proto.
    * `fields`: one spec per field under test, carrying the field name, the
      expected dtype and the expected values in `field.value`, flattened
      across the batch (see `ragged_test_case` for a padded example).
  """

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    """Creates the test case and loads `libtestexample.so` if it exists.

    Args:
      methodName: Name of the test method to run; forwarded to
        `test.TestCase`.
    """
    super(ProtoOpTestBase, self).__init__(methodName)
    # The library is optional: when the file is absent nothing is loaded.
    # NOTE(review): it presumably makes the test_example proto descriptors
    # visible to the C++ ops -- confirm against the BUILD rule.
    lib = os.path.join(os.path.dirname(__file__), "libtestexample.so")
    if os.path.isfile(lib):
      ct.cdll.LoadLibrary(lib)

  @staticmethod
  def named_parameters(extension=True):
    """Returns every scenario as a list of `(name, TestCase)` pairs.

    Args:
      extension: Whether to append the "extension" scenario at the end.

    Returns:
      A list of `(str, test_example_pb2.TestCase)` tuples, presumably meant
      for `parameterized.named_parameters` in subclasses.
    """
    parameters = [("defaults", ProtoOpTestBase.defaults_test_case()),
                  ("minmax", ProtoOpTestBase.minmax_test_case()),
                  ("nested", ProtoOpTestBase.nested_test_case()),
                  ("optional", ProtoOpTestBase.optional_test_case()),
                  ("promote", ProtoOpTestBase.promote_test_case()),
                  ("ragged", ProtoOpTestBase.ragged_test_case()),
                  ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
                  ("simple", ProtoOpTestBase.simple_test_case())]
    if extension:
      parameters.append(("extension", ProtoOpTestBase.extension_test_case()))
    return parameters

  @staticmethod
  def _add_field_spec(test_case, name, dtype, value_attr=None, values=(),
                      size=None):
    """Appends one field spec (and optionally its size) to `test_case`.

    Args:
      test_case: The `test_example_pb2.TestCase` being built.
      name: Name of the proto field under test.
      dtype: `types_pb2` DataType enum value stored on the spec.
      value_attr: Name of the repeated attribute of `field.value` that
        receives `values`. None leaves `field.value` empty so the caller can
        fill it (used for message-typed fields).
      values: Expected values, appended in order to `value_attr`.
      size: When not None, appended to `test_case.sizes` before the spec.

    Returns:
      The newly added field spec, so callers can populate it further.
    """
    if size is not None:
      test_case.sizes.append(size)
    field = test_case.fields.add()
    field.name = name
    field.dtype = dtype
    if value_attr is not None:
      getattr(field.value, value_attr).extend(values)
    return field

  @staticmethod
  def defaults_test_case():
    """One empty value; each `*_with_default` field expects its default."""
    add_field = ProtoOpTestBase._add_field_spec
    test_case = test_example_pb2.TestCase()
    test_case.values.add()  # No fields specified, so we get all defaults.
    test_case.shapes.append(1)
    # (field name, expected dtype, field.value attribute, expected default)
    defaults = [
        ("double_value_with_default", types_pb2.DT_DOUBLE, "double_value",
         1.0),
        ("float_value_with_default", types_pb2.DT_FLOAT, "float_value", 2.0),
        ("int64_value_with_default", types_pb2.DT_INT64, "int64_value", 3),
        ("sfixed64_value_with_default", types_pb2.DT_INT64, "int64_value",
         11),
        ("sint64_value_with_default", types_pb2.DT_INT64, "int64_value", 13),
        ("uint64_value_with_default", types_pb2.DT_UINT64, "uint64_value",
         4),
        ("fixed64_value_with_default", types_pb2.DT_UINT64, "uint64_value",
         6),
        ("int32_value_with_default", types_pb2.DT_INT32, "int32_value", 5),
        ("sfixed32_value_with_default", types_pb2.DT_INT32, "int32_value",
         10),
        ("sint32_value_with_default", types_pb2.DT_INT32, "int32_value", 12),
        ("uint32_value_with_default", types_pb2.DT_UINT32, "uint32_value",
         9),
        ("fixed32_value_with_default", types_pb2.DT_UINT32, "uint32_value",
         7),
        ("bool_value_with_default", types_pb2.DT_BOOL, "bool_value", True),
        ("string_value_with_default", types_pb2.DT_STRING, "string_value",
         "a"),
        ("bytes_value_with_default", types_pb2.DT_STRING, "string_value",
         "a longer default string"),
        ("enum_value_with_default", types_pb2.DT_INT32, "enum_value",
         test_example_pb2.Color.GREEN),
    ]
    for name, dtype, value_attr, default in defaults:
      # Size 0: the empty value above sets none of these fields.
      add_field(test_case, name, dtype, value_attr, [default], size=0)
    return test_case

  @staticmethod
  def minmax_test_case():
    """One value holding the extreme representable values of each type."""
    add_field = ProtoOpTestBase._add_field_spec
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    test_case.shapes.append(1)
    # Double/float: lowest, smallest positive normal, highest.
    double_limits = [-1.7976931348623158e+308, 2.2250738585072014e-308,
                     1.7976931348623158e+308]
    float_limits = [-3.402823466e+38, 1.175494351e-38, 3.402823466e+38]
    int64_limits = [-9223372036854775808, 9223372036854775807]
    uint64_limits = [0, 18446744073709551615]
    int32_limits = [-2147483648, 2147483647]
    uint32_limits = [0, 4294967295]
    bool_limits = [False, True]
    string_limits = ["", "I refer to the infinite."]
    # (proto field name, expected dtype, field.value attribute, values)
    specs = [
        ("double_value", types_pb2.DT_DOUBLE, "double_value", double_limits),
        ("float_value", types_pb2.DT_FLOAT, "float_value", float_limits),
        ("int64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sfixed64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sint64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("uint64_value", types_pb2.DT_UINT64, "uint64_value", uint64_limits),
        ("fixed64_value", types_pb2.DT_UINT64, "uint64_value",
         uint64_limits),
        ("int32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sfixed32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sint32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("uint32_value", types_pb2.DT_UINT32, "uint32_value", uint32_limits),
        ("fixed32_value", types_pb2.DT_UINT32, "uint32_value",
         uint32_limits),
        ("bool_value", types_pb2.DT_BOOL, "bool_value", bool_limits),
        ("string_value", types_pb2.DT_STRING, "string_value", string_limits),
    ]
    for name, dtype, value_attr, limits in specs:
      # The input value and the expected output carry the same numbers.
      getattr(value, name).extend(limits)
      add_field(test_case, name, dtype, value_attr, limits, size=len(limits))
    return test_case

  @staticmethod
  def nested_test_case():
    """One value with a single nested message in `message_value`."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.message_value.add().double_value = 23.5
    test_case.shapes.append(1)
    field = ProtoOpTestBase._add_field_spec(
        test_case, "message_value", types_pb2.DT_STRING, size=1)
    field.value.message_value.add().double_value = 23.5
    return test_case

  @staticmethod
  def optional_test_case():
    """`bool_value` is set once; `double_value` is left unset (size 0)."""
    add_field = ProtoOpTestBase._add_field_spec
    test_case = test_example_pb2.TestCase()
    test_case.values.add().bool_value.append(True)
    test_case.shapes.append(1)
    add_field(test_case, "bool_value", types_pb2.DT_BOOL, "bool_value",
              [True], size=1)
    add_field(test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
              [0.0], size=0)
    return test_case

  @staticmethod
  def promote_test_case():
    """32-bit fields requested with their 64-bit dtype counterparts."""
    add_field = ProtoOpTestBase._add_field_spec
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    test_case.shapes.append(1)
    # (proto field name, widened dtype, field.value attribute, max value)
    specs = [
        ("sint32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("sfixed32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("int32_value", types_pb2.DT_INT64, "int64_value", 2147483647),
        ("fixed32_value", types_pb2.DT_UINT64, "uint64_value", 4294967295),
        ("uint32_value", types_pb2.DT_UINT64, "uint64_value", 4294967295),
    ]
    for name, dtype, value_attr, number in specs:
      getattr(value, name).append(number)
      add_field(test_case, name, dtype, value_attr, [number], size=1)
    return test_case

  @staticmethod
  def ragged_test_case():
    """Two values whose `double_value` fields differ in length."""
    add_field = ProtoOpTestBase._add_field_spec
    test_case = test_example_pb2.TestCase()
    first = test_case.values.add()
    first.double_value.extend([23.5, 123.0])
    first.bool_value.append(True)
    second = test_case.values.add()
    second.double_value.append(3.1)
    second.bool_value.append(False)
    test_case.shapes.append(2)
    # 2 values x 2 fields: the doubles appear twice in `first`, once in
    # `second`; each bool appears once.
    test_case.sizes.extend([2, 1, 1, 1])
    # The shorter second row appears padded with 0.0 to the longest length.
    add_field(test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
              [23.5, 123.0, 3.1, 0.0])
    add_field(test_case, "bool_value", types_pb2.DT_BOOL, "bool_value",
              [True, False])
    return test_case

  @staticmethod
  def shaped_batch_test_case():
    """Six single-entry values arranged in a 3x2 batch."""
    add_field = ProtoOpTestBase._add_field_spec
    doubles = [23.5, 44.0, 3.14159, 1.414, -32.2, 0.0001]
    bools = [True, False, True, True, False, True]
    test_case = test_example_pb2.TestCase()
    for number, flag in zip(doubles, bools):
      value = test_case.values.add()
      value.double_value.append(number)
      value.bool_value.append(flag)
    test_case.shapes.extend([3, 2])
    # 6 values x 2 fields, each holding exactly one entry.
    test_case.sizes.extend([1] * 12)
    add_field(test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
              doubles)
    add_field(test_case, "bool_value", types_pb2.DT_BOOL, "bool_value", bools)
    return test_case

  @staticmethod
  def extension_test_case():
    """One value carrying a nested message through the `ext_value` extension."""
    extension = test_example_pb2.ext_value
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.Extensions[extension].add().double_value = 23.5
    test_case.shapes.append(1)
    # Extensions are addressed by their fully-qualified name.
    field = ProtoOpTestBase._add_field_spec(
        test_case, extension.full_name, types_pb2.DT_STRING, size=1)
    field.value.Extensions[extension].add().double_value = 23.5
    return test_case

  @staticmethod
  def simple_test_case():
    """One value with a single double, bool and enum entry."""
    add_field = ProtoOpTestBase._add_field_spec
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.double_value.append(23.5)
    value.bool_value.append(True)
    value.enum_value.append(test_example_pb2.Color.INDIGO)
    test_case.shapes.append(1)
    add_field(test_case, "double_value", types_pb2.DT_DOUBLE, "double_value",
              [23.5], size=1)
    add_field(test_case, "bool_value", types_pb2.DT_BOOL, "bool_value",
              [True], size=1)
    add_field(test_case, "enum_value", types_pb2.DT_INT32, "enum_value",
              [test_example_pb2.Color.INDIGO], size=1)
    return test_case
442    return test_case
443