xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/proto/decode_proto_op_test_base.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# =============================================================================
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# =============================================================================
16"""Tests for decode_proto op."""
17
# Standard library imports.
19import itertools
20
21from absl.testing import parameterized
22import numpy as np
23
24
25from google.protobuf import text_format
26
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.kernel_tests.proto import proto_op_test_base as test_base
30from tensorflow.python.kernel_tests.proto import test_example_pb2
31
32
class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
  """Base class for testing proto decoding ops.

  The module providing `decode_proto` is injected through the constructor and
  every test invokes the op through it, so the same test cases can be run
  against any module exposing that function.
  """

  def __init__(self, decode_module, methodName='runTest'):  # pylint: disable=invalid-name
    """DecodeProtoOpTestBase initializer.

    Args:
      decode_module: a module containing the `decode_proto` method; the tests
        call it as `decode_module.decode_proto(...)`.
      methodName: the name of the test method (same as for test.TestCase)
    """

    super(DecodeProtoOpTestBase, self).__init__(methodName)
    self._decode_module = decode_module

  def _compareValues(self, fd, vs, evs):
    """Compare lists/arrays of field values.

    Args:
      fd: descriptor of the field being compared. Its `cpp_type` selects the
        comparison used and its `name` appears in failure messages.
      vs: sized iterable of decoded values (e.g. the `.flat` view of a numpy
        array returned by the op).
      evs: the expected values for the field.
    """

    if len(vs) != len(evs):
      self.fail('Field %s decoded %d outputs, expected %d' %
                (fd.name, len(vs), len(evs)))
    for v, ev in zip(vs, evs):
      if fd.cpp_type == fd.CPPTYPE_FLOAT:
        # Special case fuzzy match for float32. TensorFlow seems to mess with
        # MAX_FLT slightly and the test doesn't work otherwise.
        # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
        # Numpy isclose() is used because assertIsClose() uses an absolute
        # value comparison.
        self.assertTrue(np.isclose(v, ev), 'expected %r, actual %r' % (ev, v))
      elif fd.cpp_type == fd.CPPTYPE_STRING:
        # In Python3 string tensor values are bytes, so text expected values
        # are encoded to match. Protobuf serializes `string` fields as UTF-8;
        # encoding as ASCII would crash on any non-ASCII test string.
        expected = ev if isinstance(ev, bytes) else ev.encode('utf-8')
        self.assertEqual(v, expected)
      else:
        # Doubles and other types pass through unscathed.
        self.assertEqual(v, ev)

  def _compareProtos(self, batch_shape, sizes, fields, field_dict):
    """Compare protos of type TestValue.

    Args:
      batch_shape: the shape of the input tensor of serialized messages.
      sizes: int matrix of repeat counts returned by decode_proto. Unused
        here; `_runDecodeProtoTests` checks it before calling this method.
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      field_dict: map from field names to decoded numpy tensors of values
    """
    del sizes  # Unused; see Args.

    # Maps an output dtype to the attribute of `FieldSpec.value` that holds
    # the expected values for scalar (non-message) fields. Built once since it
    # does not depend on the field being checked.
    tf_type_to_primitive_value_field = {
        dtypes.bool: 'bool_value',
        dtypes.float32: 'float_value',
        dtypes.float64: 'double_value',
        dtypes.int8: 'int8_value',
        dtypes.int32: 'int32_value',
        dtypes.int64: 'int64_value',
        dtypes.string: 'string_value',
        dtypes.uint8: 'uint8_value',
        dtypes.uint32: 'uint32_value',
        dtypes.uint64: 'uint64_value',
    }

    # Check that expected values match.
    for field in fields:
      values = field_dict[field.name]
      self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)

      # Values has the same shape as the input plus an extra
      # dimension for repeats.
      self.assertEqual(list(values.shape)[:-1], batch_shape)

      is_extension = 'ext_value' in field.name

      # Nested messages are represented as TF strings, requiring
      # some special handling.
      if field.name == 'message_value' or is_extension:
        vs = []
        for buf in values.flat:
          msg = test_example_pb2.PrimitiveValue()
          msg.ParseFromString(buf)
          vs.append(msg)
        if is_extension:
          evs = field.value.Extensions[test_example_pb2.ext_value]
        else:
          evs = getattr(field.value, field.name)
        if len(vs) != len(evs):
          # Report field.name: extensions are read through `Extensions`, so
          # there is no field descriptor here to take a name from.
          self.fail('Field %s decoded %d outputs, expected %d' %
                    (field.name, len(vs), len(evs)))
        for v, ev in zip(vs, evs):
          self.assertEqual(v, ev)
        continue

      fd = field.value.DESCRIPTOR.fields_by_name[field.name]
      if field.name in ('enum_value', 'enum_value_with_default'):
        tf_field_name = 'enum_value'
      else:
        tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
      if tf_field_name is None:
        self.fail('Unhandled tensorflow type %d' % field.dtype)

      self._compareValues(fd, values.flat, getattr(field.value, tf_field_name))

  def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
                           message_type, message_format, sanitize,
                           force_disordered=False):
    """Run decode tests on a batch of messages.

    Args:
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      case_sizes: expected sizes array
      batch_shape: the shape of the input tensor of serialized messages, as a
        list of ints
      batch: list of serialized messages
      message_type: descriptor name for messages
      message_format: format of messages, 'text' or 'binary'
      sanitize: whether to sanitize binary protobuf inputs
      force_disordered: whether to force fields encoded out of order. Must not
        be combined with `sanitize`.
    """

    if force_disordered:
      # Exercise code path that handles out-of-order fields by prepending extra
      # fields with tag numbers higher than any real field. Note that this won't
      # work with sanitization because that forces reserialization using a
      # trusted decoder and encoder.
      self.assertFalse(
          sanitize, 'force_disordered=True is incompatible with sanitize=True')
      extra_fields = test_example_pb2.ExtraFields()
      extra_fields.string_value = 'IGNORE ME'
      extra_fields.bool_value = False
      extra_msg = extra_fields.SerializeToString()
      batch = [extra_msg + msg for msg in batch]

    # Numpy silently truncates the strings if you don't specify dtype=object.
    batch = np.array(batch, dtype=object)
    batch = np.reshape(batch, batch_shape)

    field_names = [f.name for f in fields]
    output_types = [f.dtype for f in fields]

    with self.cached_session() as sess:
      sizes, vtensor = self._decode_module.decode_proto(
          batch,
          message_type=message_type,
          field_names=field_names,
          output_types=output_types,
          message_format=message_format,
          sanitize=sanitize)

      vlist = sess.run([sizes] + list(vtensor))
      sizes = vlist[0]
      # Values is a list of tensors, one for each field.
      value_tensors = vlist[1:]

      # Check that the repeat sizes have the right shape: one count per
      # requested field for every message in the batch.
      self.assertEqual(list(sizes.shape), batch_shape + [len(field_names)])

      # Check that the decoded sizes match the expected sizes.
      self.assertEqual(len(sizes.flat), len(case_sizes))
      self.assertAllEqual(sizes.flatten(),
                          np.array(case_sizes, dtype=np.int32))

      field_dict = dict(zip(field_names, value_tensors))

      self._compareProtos(batch_shape, sizes, fields, field_dict)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testBinary(self, case):
    """Decodes binary-serialized TestValue messages."""
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testBinaryDisordered(self, case):
    """Decodes binary messages whose fields are forced out of tag order."""
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=False,
        force_disordered=True)

  @parameterized.named_parameters(
      *test_base.ProtoOpTestBase.named_parameters(extension=False))
  def testPacked(self, case):
    """Decodes the test cases re-encoded with packed repeated fields."""
    # We test the packed representations by loading the same test case using
    # PackedTestValue instead of TestValue. To do this we rely on the text
    # format being the same for packed and unpacked fields, and reparse the
    # test message using the packed version of the proto.
    packed_batch = [
        # Note: float_format='.17g' is necessary to ensure preservation of
        # doubles and floats in text format.
        text_format.Parse(
            text_format.MessageToString(value, float_format='.17g'),
            test_example_pb2.PackedTestValue()).SerializeToString()
        for value in case.values
    ]

    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        packed_batch,
        'tensorflow.contrib.proto.PackedTestValue',
        'binary',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testText(self, case):
    """Decodes text-format TestValue messages."""
    # Note: float_format='.17g' is necessary to ensure preservation of
    # doubles and floats in text format.
    text_batch = [
        text_format.MessageToString(
            value, float_format='.17g') for value in case.values
    ]

    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        text_batch,
        'tensorflow.contrib.proto.TestValue',
        'text',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testSanitizerGood(self, case):
    """Decodes valid binary messages with sanitization enabled."""
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=True)

  @parameterized.parameters((False,), (True,))
  def testCorruptProtobuf(self, sanitize):
    """Non-protobuf input raises DataLossError, with or without sanitize."""
    corrupt_proto = 'This is not a binary protobuf'

    # Numpy silently truncates the strings if you don't specify dtype=object.
    batch = np.array(corrupt_proto, dtype=object)
    msg_type = 'tensorflow.contrib.proto.TestCase'
    field_names = ['sizes']
    field_types = [dtypes.int32]

    # assertRaisesRegexp is a deprecated alias removed from unittest in
    # Python 3.12; assertRaisesRegex is the supported name.
    with self.assertRaisesRegex(
        errors.DataLossError, 'Unable to parse binary protobuf'
        '|Failed to consume entire buffer'):
      self.evaluate(
          self._decode_module.decode_proto(
              batch,
              message_type=msg_type,
              field_names=field_names,
              output_types=field_types,
              sanitize=sanitize))

  def testOutOfOrderRepeated(self):
    """Repeated fields split across reordered fragments decode correctly."""
    fragments = [
        test_example_pb2.TestValue(double_value=[1.0]).SerializeToString(),
        test_example_pb2.TestValue(
            message_value=[test_example_pb2.PrimitiveValue(
                string_value='abc')]).SerializeToString(),
        test_example_pb2.TestValue(
            message_value=[test_example_pb2.PrimitiveValue(
                string_value='def')]).SerializeToString()
    ]
    # Index in `fragments` -> string_value of the nested message it carries.
    message_strings = {1: 'abc', 2: 'def'}
    all_fields_to_parse = ['double_value', 'message_value']
    field_types = {
        'double_value': dtypes.double,
        'message_value': dtypes.string,
    }
    # Test against all 3! permutations of fragments, and for each permutation
    # test parsing every combination of the fields (empty, single, and both).
    for indices in itertools.permutations(range(len(fragments))):
      proto = b''.join(fragments[i] for i in indices)
      # Repeated message values are expected in the order in which their
      # fragments appear in the concatenated buffer.
      expected_message_values = [
          test_example_pb2.PrimitiveValue(
              string_value=message_strings[i]).SerializeToString()
          for i in indices if i in message_strings
      ]

      expected_field_values = {
          'double_value': [[1.0]],
          'message_value': [expected_message_values],
      }

      # `+ 1` so that the combination containing every field is included.
      for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
        for comb in itertools.combinations(
            all_fields_to_parse, num_fields_to_parse):
          parsed_values = self.evaluate(
              self._decode_module.decode_proto(
                  [proto],
                  message_type='tensorflow.contrib.proto.TestValue',
                  field_names=comb,
                  output_types=[field_types[f] for f in comb],
                  sanitize=False)).values
          self.assertLen(parsed_values, len(comb))
          for field_name, parsed in zip(comb, parsed_values):
            self.assertAllEqual(parsed, expected_field_values[field_name],
                                'perm: {}, comb: {}'.format(indices, comb))
366                                'perm: {}, comb: {}'.format(indices, comb))
367