# =============================================================================
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for decode_proto op."""

import itertools

from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.kernel_tests.proto import proto_op_test_base as test_base
from tensorflow.python.kernel_tests.proto import test_example_pb2

# Maps the TensorFlow dtype of a decoded scalar field to the attribute of
# `FieldSpec.value` that holds its expected values. Built once at import time
# instead of on every iteration of the comparison loop.
_TF_TYPE_TO_PRIMITIVE_VALUE_FIELD = {
    dtypes.bool: 'bool_value',
    dtypes.float32: 'float_value',
    dtypes.float64: 'double_value',
    dtypes.int8: 'int8_value',
    dtypes.int32: 'int32_value',
    dtypes.int64: 'int64_value',
    dtypes.string: 'string_value',
    dtypes.uint8: 'uint8_value',
    dtypes.uint32: 'uint32_value',
    dtypes.uint64: 'uint64_value',
}


class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
  """Base class for testing proto decoding ops."""

  def __init__(self, decode_module, methodName='runTest'):  # pylint: disable=invalid-name
    """DecodeProtoOpTestBase initializer.

    Args:
      decode_module: a module containing the `decode_proto` method
      methodName: the name of the test method (same as for test.TestCase)
    """

    super(DecodeProtoOpTestBase, self).__init__(methodName)
    self._decode_module = decode_module

  def _compareValues(self, fd, vs, evs):
    """Compare lists/arrays of field values.

    Args:
      fd: the protobuf FieldDescriptor of the field being compared; its
        `cpp_type` selects the comparison strategy.
      vs: the decoded values (supports `len()` and iteration).
      evs: the expected values taken from the test case proto.
    """

    if len(vs) != len(evs):
      self.fail('Field %s decoded %d outputs, expected %d' %
                (fd.name, len(vs), len(evs)))
    for v, ev in zip(vs, evs):
      # Special case fuzzy match for float32. TensorFlow seems to mess with
      # MAX_FLT slightly and the test doesn't work otherwise.
      # TODO(nix): ask on TF list about why MAX_FLT doesn't pass through.
      if fd.cpp_type == fd.CPPTYPE_FLOAT:
        # Numpy isclose() is better than assertIsClose() which uses an absolute
        # value comparison.
        self.assertTrue(np.isclose(v, ev), 'expected %r, actual %r' % (ev, v))
      elif fd.cpp_type == fd.CPPTYPE_STRING:
        # In Python3 string tensor values will be represented as bytes, so we
        # reencode the proto values to match that.
        self.assertEqual(v, ev.encode('ascii'))
      else:
        # Doubles and other types pass through unscathed.
        self.assertEqual(v, ev)

  def _compareProtos(self, batch_shape, sizes, fields, field_dict):
    """Compare protos of type TestValue.

    Args:
      batch_shape: the shape of the input tensor of serialized messages.
      sizes: int matrix of repeat counts returned by decode_proto. Not used
        here; `_runDecodeProtoTests` checks the sizes before calling this.
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      field_dict: map from field names to decoded numpy tensors of values
    """

    # Check that expected values match.
    for field in fields:
      values = field_dict[field.name]
      self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)

      # Values has the same shape as the input plus an extra
      # dimension for repeats.
      self.assertEqual(list(values.shape)[:-1], batch_shape)

      is_extension = 'ext_value' in field.name

      # Nested messages are represented as TF strings, requiring
      # some special handling.
      if field.name == 'message_value' or is_extension:
        vs = []
        for buf in values.flat:
          msg = test_example_pb2.PrimitiveValue()
          msg.ParseFromString(buf)
          vs.append(msg)
        if is_extension:
          evs = field.value.Extensions[test_example_pb2.ext_value]
        else:
          evs = getattr(field.value, field.name)
        if len(vs) != len(evs):
          # Use `field.name`: extension fields have no FieldDescriptor in
          # `fields_by_name` to take a name from.
          self.fail('Field %s decoded %d outputs, expected %d' %
                    (field.name, len(vs), len(evs)))
        for v, ev in zip(vs, evs):
          self.assertEqual(v, ev)
        continue

      fd = field.value.DESCRIPTOR.fields_by_name[field.name]

      if field.name in ('enum_value', 'enum_value_with_default'):
        tf_field_name = 'enum_value'
      else:
        tf_field_name = _TF_TYPE_TO_PRIMITIVE_VALUE_FIELD.get(field.dtype)
      if tf_field_name is None:
        self.fail('Unhandled tensorflow type %d' % field.dtype)

      self._compareValues(fd, values.flat, getattr(field.value, tf_field_name))

  def _runDecodeProtoTests(self, fields, case_sizes, batch_shape, batch,
                           message_type, message_format, sanitize,
                           force_disordered=False):
    """Run decode tests on a batch of messages.

    Args:
      fields: list of test_example_pb2.FieldSpec (types and expected values)
      case_sizes: expected sizes array
      batch_shape: the shape of the input tensor of serialized messages
        (a list of ints)
      batch: list of serialized messages
      message_type: descriptor name for messages
      message_format: format of messages, 'text' or 'binary'
      sanitize: whether to sanitize binary protobuf inputs
      force_disordered: whether to force fields encoded out of order.
    """

    if force_disordered:
      # Exercise code path that handles out-of-order fields by prepending extra
      # fields with tag numbers higher than any real field. Note that this won't
      # work with sanitization because that forces reserialization using a
      # trusted decoder and encoder.
      self.assertFalse(sanitize,
                       'force_disordered cannot be combined with sanitize')
      extra_fields = test_example_pb2.ExtraFields()
      extra_fields.string_value = 'IGNORE ME'
      extra_fields.bool_value = False
      extra_msg = extra_fields.SerializeToString()
      batch = [extra_msg + msg for msg in batch]

    # Numpy silently truncates the strings if you don't specify dtype=object.
    batch = np.array(batch, dtype=object)
    batch = np.reshape(batch, batch_shape)

    field_names = [f.name for f in fields]
    output_types = [f.dtype for f in fields]

    with self.cached_session() as sess:
      sizes, vtensor = self._decode_module.decode_proto(
          batch,
          message_type=message_type,
          field_names=field_names,
          output_types=output_types,
          message_format=message_format,
          sanitize=sanitize)

      vlist = sess.run([sizes] + vtensor)
      sizes = vlist[0]
      # Values is a list of tensors, one for each field.
      value_tensors = vlist[1:]

      # Check that the repeat sizes are correct. Comparing shapes as lists
      # fails cleanly on a rank mismatch, unlike elementwise numpy `==`.
      self.assertEqual(list(sizes.shape), batch_shape + [len(field_names)])

      # Check that the decoded sizes match the expected sizes.
      self.assertEqual(len(sizes.flat), len(case_sizes))
      self.assertAllEqual(sizes.ravel(), np.array(case_sizes, dtype=np.int32))

      field_dict = dict(zip(field_names, value_tensors))

      self._compareProtos(batch_shape, sizes, fields, field_dict)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testBinary(self, case):
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testBinaryDisordered(self, case):
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=False,
        force_disordered=True)

  @parameterized.named_parameters(
      *test_base.ProtoOpTestBase.named_parameters(extension=False))
  def testPacked(self, case):
    # Now try with the packed serialization.
    #
    # We test the packed representations by loading the same test case using
    # PackedTestValue instead of TestValue. To do this we rely on the text
    # format being the same for packed and unpacked fields, and reparse the
    # test message using the packed version of the proto.
    packed_batch = [
        # Note: float_format='.17g' is necessary to ensure preservation of
        # doubles and floats in text format.
        text_format.Parse(
            text_format.MessageToString(value, float_format='.17g'),
            test_example_pb2.PackedTestValue()).SerializeToString()
        for value in case.values
    ]

    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        packed_batch,
        'tensorflow.contrib.proto.PackedTestValue',
        'binary',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testText(self, case):
    # Note: float_format='.17g' is necessary to ensure preservation of
    # doubles and floats in text format.
    text_batch = [
        text_format.MessageToString(value, float_format='.17g')
        for value in case.values
    ]

    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        text_batch,
        'tensorflow.contrib.proto.TestValue',
        'text',
        sanitize=False)

  @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
  def testSanitizerGood(self, case):
    batch = [value.SerializeToString() for value in case.values]
    self._runDecodeProtoTests(
        case.fields,
        case.sizes,
        list(case.shapes),
        batch,
        'tensorflow.contrib.proto.TestValue',
        'binary',
        sanitize=True)

  @parameterized.parameters((False,), (True,))
  def testCorruptProtobuf(self, sanitize):
    corrupt_proto = 'This is not a binary protobuf'

    # Numpy silently truncates the strings if you don't specify dtype=object.
    batch = np.array(corrupt_proto, dtype=object)
    msg_type = 'tensorflow.contrib.proto.TestCase'
    field_names = ['sizes']
    field_types = [dtypes.int32]

    # Either error message is accepted (the two literals are concatenated
    # into a single alternation regex).
    with self.assertRaisesRegex(
        errors.DataLossError, 'Unable to parse binary protobuf'
        '|Failed to consume entire buffer'):
      self.evaluate(
          self._decode_module.decode_proto(
              batch,
              message_type=msg_type,
              field_names=field_names,
              output_types=field_types,
              sanitize=sanitize))

  def testOutOfOrderRepeated(self):
    fragments = [
        test_example_pb2.TestValue(double_value=[1.0]).SerializeToString(),
        test_example_pb2.TestValue(
            message_value=[test_example_pb2.PrimitiveValue(
                string_value='abc')]).SerializeToString(),
        test_example_pb2.TestValue(
            message_value=[test_example_pb2.PrimitiveValue(
                string_value='def')]).SerializeToString()
    ]
    # Serialized PrimitiveValue carried by each message_value fragment, keyed
    # by that fragment's index in `fragments`.
    submessage_by_fragment = {
        1: test_example_pb2.PrimitiveValue(
            string_value='abc').SerializeToString(),
        2: test_example_pb2.PrimitiveValue(
            string_value='def').SerializeToString(),
    }
    all_fields_to_parse = ['double_value', 'message_value']
    field_types = {
        'double_value': dtypes.double,
        'message_value': dtypes.string,
    }
    # Test against all 3! permutations of fragments, and for each permutation
    # test parsing every subset of the fields: none, each field alone, and
    # both fields together.
    for indices in itertools.permutations(range(len(fragments))):
      proto = b''.join(fragments[i] for i in indices)
      # Repeated message_value entries are expected in the order in which
      # their fragments appear in the concatenated proto.
      expected_message_values = [
          submessage_by_fragment[i]
          for i in indices
          if i in submessage_by_fragment
      ]

      expected_field_values = {
          'double_value': [[1.0]],
          'message_value': [expected_message_values],
      }

      # `+ 1` so that the subset containing every field is included.
      for num_fields_to_parse in range(len(all_fields_to_parse) + 1):
        for comb in itertools.combinations(all_fields_to_parse,
                                           num_fields_to_parse):
          parsed_values = self.evaluate(
              self._decode_module.decode_proto(
                  [proto],
                  message_type='tensorflow.contrib.proto.TestValue',
                  field_names=comb,
                  output_types=[field_types[f] for f in comb],
                  sanitize=False)).values
          self.assertLen(parsed_values, len(comb))
          for field_name, parsed in zip(comb, parsed_values):
            self.assertAllEqual(parsed, expected_field_values[field_name],
                                'perm: {}, comb: {}'.format(indices, comb))