# =============================================================================
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Test case base for testing proto operations."""

import ctypes as ct
import os

from tensorflow.core.framework import types_pb2
from tensorflow.python.kernel_tests.proto import test_example_pb2
from tensorflow.python.platform import test


class ProtoOpTestBase(test.TestCase):
  """Base class for testing proto decoding and encoding ops.

  Each ``*_test_case`` builder returns a ``test_example_pb2.TestCase`` with:
    * ``values``: the input records;
    * ``shapes``: batch dimensions whose product is the number of records;
    * ``sizes``: one count per (record, field) pair, i.e. how many values
      that field actually holds in that record;
    * ``fields``: the expected output for each requested field name, with its
      output ``dtype`` and expected values.
  """

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    super(ProtoOpTestBase, self).__init__(methodName)
    # Load the optional native test library when it was built next to this
    # file. NOTE(review): presumably it registers the test_example proto
    # descriptors with the C++ runtime; confirm against the BUILD rule.
    lib = os.path.join(os.path.dirname(__file__), "libtestexample.so")
    if os.path.isfile(lib):
      ct.cdll.LoadLibrary(lib)

  @staticmethod
  def _add_field(test_case, name, dtype, value_attr, values):
    """Appends one expected-output field to `test_case`.

    Args:
      test_case: The `test_example_pb2.TestCase` being built.
      name: Name of the proto field the op is asked to handle.
      dtype: `types_pb2` dtype enum of the field's output tensor.
      value_attr: Repeated attribute of `field.value` (e.g. "int64_value")
        that receives the expected values.
      values: Iterable of expected values, in order.

    Returns:
      The newly added field message.
    """
    field = test_case.fields.add()
    field.name = name
    field.dtype = dtype
    getattr(field.value, value_attr).extend(values)
    return field

  @staticmethod
  def named_parameters(extension=True):
    """Returns `(name, TestCase)` pairs for parameterized tests.

    Args:
      extension: If true, also include the proto-extension test case.

    Returns:
      A list of `(str, test_example_pb2.TestCase)` tuples.
    """
    parameters = [("defaults", ProtoOpTestBase.defaults_test_case()),
                  ("minmax", ProtoOpTestBase.minmax_test_case()),
                  ("nested", ProtoOpTestBase.nested_test_case()),
                  ("optional", ProtoOpTestBase.optional_test_case()),
                  ("promote", ProtoOpTestBase.promote_test_case()),
                  ("ragged", ProtoOpTestBase.ragged_test_case()),
                  ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
                  ("simple", ProtoOpTestBase.simple_test_case())]
    if extension:
      parameters.append(("extension", ProtoOpTestBase.extension_test_case()))
    return parameters

  @staticmethod
  def defaults_test_case():
    """One empty record; each requested field yields its default value."""
    test_case = test_example_pb2.TestCase()
    test_case.values.add()  # No fields specified, so we get all defaults.
    test_case.shapes.append(1)
    # (field name, output dtype, attribute on field.value, expected default)
    expected_defaults = [
        ("double_value_with_default", types_pb2.DT_DOUBLE, "double_value",
         1.0),
        ("float_value_with_default", types_pb2.DT_FLOAT, "float_value", 2.0),
        ("int64_value_with_default", types_pb2.DT_INT64, "int64_value", 3),
        ("sfixed64_value_with_default", types_pb2.DT_INT64, "int64_value",
         11),
        ("sint64_value_with_default", types_pb2.DT_INT64, "int64_value", 13),
        ("uint64_value_with_default", types_pb2.DT_UINT64, "uint64_value", 4),
        ("fixed64_value_with_default", types_pb2.DT_UINT64, "uint64_value",
         6),
        ("int32_value_with_default", types_pb2.DT_INT32, "int32_value", 5),
        ("sfixed32_value_with_default", types_pb2.DT_INT32, "int32_value",
         10),
        ("sint32_value_with_default", types_pb2.DT_INT32, "int32_value", 12),
        ("uint32_value_with_default", types_pb2.DT_UINT32, "uint32_value", 9),
        ("fixed32_value_with_default", types_pb2.DT_UINT32, "uint32_value",
         7),
        ("bool_value_with_default", types_pb2.DT_BOOL, "bool_value", True),
        ("string_value_with_default", types_pb2.DT_STRING, "string_value",
         "a"),
        ("bytes_value_with_default", types_pb2.DT_STRING, "string_value",
         "a longer default string"),
        ("enum_value_with_default", types_pb2.DT_INT32, "enum_value",
         test_example_pb2.Color.GREEN),
    ]
    for name, dtype, value_attr, default in expected_defaults:
      # Size 0: the field is absent from the record, yet the expected output
      # still carries one value (the default).
      test_case.sizes.append(0)
      ProtoOpTestBase._add_field(test_case, name, dtype, value_attr,
                                 [default])
    return test_case

  @staticmethod
  def minmax_test_case():
    """One record holding the limit values of every scalar field type."""
    double_limits = [-1.7976931348623158e+308, 2.2250738585072014e-308,
                     1.7976931348623158e+308]
    float_limits = [-3.402823466e+38, 1.175494351e-38, 3.402823466e+38]
    int64_limits = [-9223372036854775808, 9223372036854775807]
    uint64_limits = [0, 18446744073709551615]
    int32_limits = [-2147483648, 2147483647]
    uint32_limits = [0, 4294967295]
    # (field name, output dtype, attribute on field.value, values). The same
    # values are written into the input record and expected back unchanged.
    cases = [
        ("double_value", types_pb2.DT_DOUBLE, "double_value", double_limits),
        ("float_value", types_pb2.DT_FLOAT, "float_value", float_limits),
        ("int64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sfixed64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("sint64_value", types_pb2.DT_INT64, "int64_value", int64_limits),
        ("uint64_value", types_pb2.DT_UINT64, "uint64_value", uint64_limits),
        ("fixed64_value", types_pb2.DT_UINT64, "uint64_value", uint64_limits),
        ("int32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sfixed32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("sint32_value", types_pb2.DT_INT32, "int32_value", int32_limits),
        ("uint32_value", types_pb2.DT_UINT32, "uint32_value", uint32_limits),
        ("fixed32_value", types_pb2.DT_UINT32, "uint32_value", uint32_limits),
        ("bool_value", types_pb2.DT_BOOL, "bool_value", [False, True]),
        ("string_value", types_pb2.DT_STRING, "string_value",
         ["", "I refer to the infinite."]),
    ]
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    test_case.shapes.append(1)
    for name, dtype, value_attr, values in cases:
      getattr(value, name).extend(values)
      test_case.sizes.append(len(values))
      ProtoOpTestBase._add_field(test_case, name, dtype, value_attr, values)
    return test_case

  @staticmethod
  def nested_test_case():
    """One record with a single nested message in `message_value`."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    message_value = value.message_value.add()
    message_value.double_value = 23.5
    test_case.shapes.append(1)
    test_case.sizes.append(1)
    field = test_case.fields.add()
    field.name = "message_value"
    # Nested messages are requested with DT_STRING output (presumably the
    # serialized sub-message).
    field.dtype = types_pb2.DT_STRING
    message_value = field.value.message_value.add()
    message_value.double_value = 23.5
    return test_case

  @staticmethod
  def optional_test_case():
    """Requests one field present in the record and one that is absent."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.bool_value.append(True)
    test_case.shapes.append(1)
    test_case.sizes.append(1)
    ProtoOpTestBase._add_field(test_case, "bool_value", types_pb2.DT_BOOL,
                               "bool_value", [True])
    # double_value is never set on the record: size 0, expected value 0.0.
    test_case.sizes.append(0)
    ProtoOpTestBase._add_field(test_case, "double_value", types_pb2.DT_DOUBLE,
                               "double_value", [0.0])
    return test_case

  @staticmethod
  def promote_test_case():
    """32-bit fields requested with widened 64-bit output dtypes."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    test_case.shapes.append(1)
    # (field name, max value, widened output dtype, attribute on field.value)
    cases = [
        ("sint32_value", 2147483647, types_pb2.DT_INT64, "int64_value"),
        ("sfixed32_value", 2147483647, types_pb2.DT_INT64, "int64_value"),
        ("int32_value", 2147483647, types_pb2.DT_INT64, "int64_value"),
        ("fixed32_value", 4294967295, types_pb2.DT_UINT64, "uint64_value"),
        ("uint32_value", 4294967295, types_pb2.DT_UINT64, "uint64_value"),
    ]
    for name, max_value, dtype, value_attr in cases:
      getattr(value, name).append(max_value)
      test_case.sizes.append(1)
      ProtoOpTestBase._add_field(test_case, name, dtype, value_attr,
                                 [max_value])
    return test_case

  @staticmethod
  def ragged_test_case():
    """Two records whose `double_value` counts differ (2 vs. 1)."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    value.double_value.extend([23.5, 123.0])
    value.bool_value.append(True)
    value = test_case.values.add()
    value.double_value.append(3.1)
    value.bool_value.append(False)
    test_case.shapes.append(2)
    test_case.sizes.extend([2, 1, 1, 1])
    # NOTE(review): the trailing 0.0 presumably pads the shorter record up to
    # the batch-wide max count of 2; confirm against the op's output layout.
    ProtoOpTestBase._add_field(test_case, "double_value", types_pb2.DT_DOUBLE,
                               "double_value", [23.5, 123.0, 3.1, 0.0])
    ProtoOpTestBase._add_field(test_case, "bool_value", types_pb2.DT_BOOL,
                               "bool_value", [True, False])
    return test_case

  @staticmethod
  def shaped_batch_test_case():
    """Six single-valued records arranged as a [3, 2] batch."""
    records = [(23.5, True), (44.0, False), (3.14159, True), (1.414, True),
               (-32.2, False), (0.0001, True)]
    test_case = test_example_pb2.TestCase()
    for double_value, bool_value in records:
      value = test_case.values.add()
      value.double_value.append(double_value)
      value.bool_value.append(bool_value)
    test_case.shapes.extend([3, 2])
    # One size per (record, field) pair: two fields, one value each.
    test_case.sizes.extend([1] * (len(records) * 2))
    ProtoOpTestBase._add_field(test_case, "double_value", types_pb2.DT_DOUBLE,
                               "double_value", [d for d, _ in records])
    ProtoOpTestBase._add_field(test_case, "bool_value", types_pb2.DT_BOOL,
                               "bool_value", [b for _, b in records])
    return test_case

  @staticmethod
  def extension_test_case():
    """One record carrying a nested message via the `ext_value` extension."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    message_value = value.Extensions[test_example_pb2.ext_value].add()
    message_value.double_value = 23.5
    test_case.shapes.append(1)
    test_case.sizes.append(1)
    field = test_case.fields.add()
    # The extension is requested by its fully-qualified name.
    field.name = test_example_pb2.ext_value.full_name
    field.dtype = types_pb2.DT_STRING
    message_value = field.value.Extensions[test_example_pb2.ext_value].add()
    message_value.double_value = 23.5
    return test_case

  @staticmethod
  def simple_test_case():
    """One record with a single double, bool and enum value."""
    test_case = test_example_pb2.TestCase()
    value = test_case.values.add()
    test_case.shapes.append(1)
    # (field name, output dtype, value); the input attribute, the requested
    # field and the field.value attribute all share the same name here.
    cases = [
        ("double_value", types_pb2.DT_DOUBLE, 23.5),
        ("bool_value", types_pb2.DT_BOOL, True),
        ("enum_value", types_pb2.DT_INT32, test_example_pb2.Color.INDIGO),
    ]
    for name, dtype, expected in cases:
      getattr(value, name).append(expected)
      test_case.sizes.append(1)
      ProtoOpTestBase._add_field(test_case, name, dtype, name, [expected])
    return test_case