1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for `tf.data.experimental.parse_example_dataset()`."""
16
17import copy
18
19from absl.testing import parameterized
20import numpy as np
21
22from tensorflow.core.example import example_pb2
23from tensorflow.core.example import feature_pb2
24from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsing_ops
25from tensorflow.python.data.kernel_tests import checkpoint_test_base
26from tensorflow.python.data.kernel_tests import test_base
27from tensorflow.python.data.kernel_tests import tf_record_test_base
28from tensorflow.python.data.ops import dataset_ops
29from tensorflow.python.data.ops import options as options_lib
30from tensorflow.python.eager import context
31from tensorflow.python.framework import combinations
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors_impl
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import sparse_tensor
36from tensorflow.python.ops import parsing_ops
37from tensorflow.python.ops.ragged import ragged_factory_ops
38from tensorflow.python.platform import test
39
# Shorthand constructors for building `Example` protos.
example = example_pb2.Example
feature = feature_pb2.Feature


def features(d):
  """Returns a `Features` proto whose `feature` field is set from `d`."""
  return feature_pb2.Features(feature=d)


def bytes_feature(v):
  """Returns a `Feature` whose `bytes_list` holds the values in `v`."""
  return feature(bytes_list=feature_pb2.BytesList(value=v))


def int64_feature(v):
  """Returns a `Feature` whose `int64_list` holds the values in `v`."""
  return feature(int64_list=feature_pb2.Int64List(value=v))


def float_feature(v):
  """Returns a `Feature` whose `float_list` holds the values in `v`."""
  return feature(float_list=feature_pb2.FloatList(value=v))


# Shorthand constructors for building `SequenceExample` protos.
sequence_example = example_pb2.SequenceExample


def feature_list(l):
  """Returns a `FeatureList` proto whose `feature` field is set from `l`."""
  return feature_pb2.FeatureList(feature=l)


def feature_lists(d):
  """Returns a `FeatureLists` proto whose `feature_list` is set from `d`."""
  return feature_pb2.FeatureLists(feature_list=d)
51
52
53class ParseExampleDatasetTest(test_base.DatasetTestBase,
54                              parameterized.TestCase):
55
56  def _compare_output_to_expected(self, dict_tensors, expected_tensors):
57    self.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
58
59    for k, v in sorted(dict_tensors.items()):
60      expected_v = expected_tensors[k]
61      self.assertValuesEqual(expected_v, v)
62
63  def _test(self,
64            input_tensor,
65            feature_val,
66            expected_values=None,
67            expected_err=None,
68            create_iterator_twice=False):
69
70    if expected_err:
71      with self.assertRaisesWithPredicateMatch(expected_err[0],
72                                               expected_err[1]):
73        dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
74            contrib_parsing_ops.parse_example_dataset(feature_val))
75        get_next = self.getNext(dataset)
76        self.evaluate(get_next())
77      return
78    else:
79      # Returns dict w/ Tensors and SparseTensors.
80      # Check values.
81      dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
82          contrib_parsing_ops.parse_example_dataset(feature_val))
83      get_next = self.getNext(dataset)
84      result = self.evaluate(get_next())
85      self._compare_output_to_expected(result, expected_values)
86      with self.assertRaises(errors_impl.OutOfRangeError):
87        self.evaluate(get_next())
88      with self.assertRaises(errors_impl.OutOfRangeError):
89        self.evaluate(get_next())
90      if create_iterator_twice:
91        get_next = self.getNext(dataset)
92        result = self.evaluate(get_next())
93        self._compare_output_to_expected(result, expected_values)
94        with self.assertRaises(errors_impl.OutOfRangeError):
95          self.evaluate(get_next())
96    # Check shapes; if serialized is a Tensor we need its size to
97    # properly check.
98    batch_size = (
99        self.evaluate(input_tensor).size if isinstance(input_tensor, ops.Tensor)
100        else np.asarray(input_tensor).size)
101    for k, f in feature_val.items():
102      if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
103        self.assertEqual(
104            dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[0],
105            batch_size)
106      elif isinstance(f, parsing_ops.VarLenFeature):
107        self.assertEqual(
108            dataset_ops.get_legacy_output_shapes(dataset)[k].as_list()[1], None)
109
110  @combinations.generate(test_base.default_test_combinations())
111  def testEmptySerializedWithAllDefaults(self):
112    sparse_name = "st_a"
113    a_name = "a"
114    b_name = "b"
115    c_name = "c:has_a_tricky_name"
116    a_default = [0, 42, 0]
117    b_default = np.random.rand(3, 3).astype(bytes)
118    c_default = np.random.rand(2).astype(np.float32)
119
120    expected_st_a = sparse_tensor.SparseTensorValue(  # indices, values, shape
121        np.empty((0, 2), dtype=np.int64),  # indices
122        np.empty((0,), dtype=np.int64),  # sp_a is DT_INT64
123        np.array([2, 0], dtype=np.int64))  # batch == 2, max_elems = 0
124
125    expected_output = {
126        sparse_name: expected_st_a,
127        a_name: np.array(2 * [[a_default]]),
128        b_name: np.array(2 * [b_default]),
129        c_name: np.array(2 * [c_default]),
130    }
131
132    self._test(
133        ops.convert_to_tensor(["", ""]), {
134            sparse_name:
135                parsing_ops.VarLenFeature(dtypes.int64),
136            a_name:
137                parsing_ops.FixedLenFeature(
138                    (1, 3), dtypes.int64, default_value=a_default),
139            b_name:
140                parsing_ops.FixedLenFeature(
141                    (3, 3), dtypes.string, default_value=b_default),
142            c_name:
143                parsing_ops.FixedLenFeature(
144                    (2,), dtypes.float32, default_value=c_default),
145        },
146        expected_values=expected_output,
147        create_iterator_twice=True)
148
149  @combinations.generate(test_base.graph_only_combinations())
150  def testEmptySerializedWithoutDefaultsShouldFail(self):
151    input_features = {
152        "st_a":
153            parsing_ops.VarLenFeature(dtypes.int64),
154        "a":
155            parsing_ops.FixedLenFeature(
156                (1, 3), dtypes.int64, default_value=[0, 42, 0]),
157        "b":
158            parsing_ops.FixedLenFeature(
159                (3, 3),
160                dtypes.string,
161                default_value=np.random.rand(3, 3).astype(bytes)),
162        # Feature "c" is missing a default, this gap will cause failure.
163        "c":
164            parsing_ops.FixedLenFeature(
165                (2,), dtype=dtypes.float32),
166    }
167
168    # Edge case where the key is there but the feature value is empty
169    original = example(features=features({"c": feature()}))
170    self._test(
171        [original.SerializeToString()],
172        input_features,
173        expected_err=(errors_impl.InvalidArgumentError,
174                      "Feature: c \\(data type: float\\) is required"))
175
176    # Standard case of missing key and value.
177    self._test(
178        ["", ""],
179        input_features,
180        expected_err=(errors_impl.InvalidArgumentError,
181                      "Feature: c \\(data type: float\\) is required"))
182
183  @combinations.generate(test_base.graph_only_combinations())
184  def testDenseNotMatchingShapeShouldFail(self):
185    original = [
186        example(features=features({
187            "a": float_feature([1, 1, 3]),
188        })), example(features=features({
189            "a": float_feature([-1, -1]),
190        }))
191    ]
192
193    serialized = [m.SerializeToString() for m in original]
194
195    self._test(
196        ops.convert_to_tensor(serialized),
197        {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
198        expected_err=(errors_impl.InvalidArgumentError,
199                      "Key: a, Index: 1.  Number of float values"))
200
201  @combinations.generate(test_base.default_test_combinations())
202  def testDenseDefaultNoShapeShouldFail(self):
203    original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
204
205    serialized = [m.SerializeToString() for m in original]
206
207    self._test(
208        ops.convert_to_tensor(serialized),
209        {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
210        expected_err=(ValueError, "Missing shape for feature a"))
211
212  @combinations.generate(test_base.default_test_combinations())
213  def testSerializedContainingSparse(self):
214    original = [
215        example(features=features({
216            "st_c": float_feature([3, 4])
217        })),
218        example(features=features({
219            "st_c": float_feature([]),  # empty float list
220        })),
221        example(features=features({
222            "st_d": feature(),  # feature with nothing in it
223        })),
224        example(features=features({
225            "st_c": float_feature([1, 2, -1]),
226            "st_d": bytes_feature([b"hi"])
227        }))
228    ]
229
230    serialized = [m.SerializeToString() for m in original]
231
232    expected_st_c = sparse_tensor.SparseTensorValue(  # indices, values, shape
233        np.array([[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64),
234        np.array([3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32),
235        np.array([4, 3], dtype=np.int64))  # batch == 2, max_elems = 3
236
237    expected_st_d = sparse_tensor.SparseTensorValue(  # indices, values, shape
238        np.array([[3, 0]], dtype=np.int64), np.array(["hi"], dtype=bytes),
239        np.array([4, 1], dtype=np.int64))  # batch == 2, max_elems = 1
240
241    expected_output = {
242        "st_c": expected_st_c,
243        "st_d": expected_st_d,
244    }
245
246    self._test(
247        ops.convert_to_tensor(serialized), {
248            "st_c": parsing_ops.VarLenFeature(dtypes.float32),
249            "st_d": parsing_ops.VarLenFeature(dtypes.string)
250        },
251        expected_values=expected_output,
252        create_iterator_twice=True)
253
254  @combinations.generate(test_base.default_test_combinations())
255  def testSerializedContainingSparseFeature(self):
256    original = [
257        example(features=features({
258            "val": float_feature([3, 4]),
259            "idx": int64_feature([5, 10])
260        })),
261        example(features=features({
262            "val": float_feature([]),  # empty float list
263            "idx": int64_feature([])
264        })),
265        example(features=features({
266            "val": feature(),  # feature with nothing in it
267            # missing idx feature
268        })),
269        example(features=features({
270            "val": float_feature([1, 2, -1]),
271            "idx":
272                int64_feature([0, 9, 3])  # unsorted
273        }))
274    ]
275
276    serialized = [m.SerializeToString() for m in original]
277
278    expected_sp = sparse_tensor.SparseTensorValue(  # indices, values, shape
279        np.array([[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
280        np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
281        np.array([4, 13], dtype=np.int64))  # batch == 4, max_elems = 13
282
283    expected_output = {"sp": expected_sp,}
284
285    self._test(
286        ops.convert_to_tensor(serialized),
287        {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
288        expected_values=expected_output,
289        create_iterator_twice=True)
290
291  @combinations.generate(test_base.default_test_combinations())
292  def testSerializedContainingSparseFeatureReuse(self):
293    original = [
294        example(features=features({
295            "val1": float_feature([3, 4]),
296            "val2": float_feature([5, 6]),
297            "idx": int64_feature([5, 10])
298        })),
299        example(features=features({
300            "val1": float_feature([]),  # empty float list
301            "idx": int64_feature([])
302        })),
303    ]
304
305    serialized = [m.SerializeToString() for m in original]
306
307    expected_sp1 = sparse_tensor.SparseTensorValue(  # indices, values, shape
308        np.array([[0, 5], [0, 10]], dtype=np.int64),
309        np.array([3.0, 4.0], dtype=np.float32),
310        np.array([2, 13], dtype=np.int64))  # batch == 2, max_elems = 13
311
312    expected_sp2 = sparse_tensor.SparseTensorValue(  # indices, values, shape
313        np.array([[0, 5], [0, 10]], dtype=np.int64),
314        np.array([5.0, 6.0], dtype=np.float32),
315        np.array([2, 7], dtype=np.int64))  # batch == 2, max_elems = 13
316
317    expected_output = {
318        "sp1": expected_sp1,
319        "sp2": expected_sp2,
320    }
321
322    self._test(
323        ops.convert_to_tensor(serialized), {
324            "sp1":
325                parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
326            "sp2":
327                parsing_ops.SparseFeature(
328                    "idx", "val2", dtypes.float32, size=7, already_sorted=True)
329        },
330        expected_values=expected_output,
331        create_iterator_twice=True)
332
333  @combinations.generate(test_base.default_test_combinations())
334  def testSerializedContaining3DSparseFeature(self):
335    original = [
336        example(features=features({
337            "val": float_feature([3, 4]),
338            "idx0": int64_feature([5, 10]),
339            "idx1": int64_feature([0, 2]),
340        })),
341        example(features=features({
342            "val": float_feature([]),  # empty float list
343            "idx0": int64_feature([]),
344            "idx1": int64_feature([]),
345        })),
346        example(features=features({
347            "val": feature(),  # feature with nothing in it
348            # missing idx feature
349        })),
350        example(features=features({
351            "val": float_feature([1, 2, -1]),
352            "idx0": int64_feature([0, 9, 3]),  # unsorted
353            "idx1": int64_feature([1, 0, 2]),
354        }))
355    ]
356
357    serialized = [m.SerializeToString() for m in original]
358
359    expected_sp = sparse_tensor.SparseTensorValue(
360        # indices
361        np.array([[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
362                 dtype=np.int64),
363        # values
364        np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
365        # shape batch == 4, max_elems = 13
366        np.array([4, 13, 3], dtype=np.int64))
367
368    expected_output = {"sp": expected_sp,}
369
370    self._test(
371        ops.convert_to_tensor(serialized), {
372            "sp":
373                parsing_ops.SparseFeature(["idx0", "idx1"], "val",
374                                          dtypes.float32, [13, 3])
375        },
376        expected_values=expected_output,
377        create_iterator_twice=True)
378
379  @combinations.generate(test_base.default_test_combinations())
380  def testSerializedContainingDense(self):
381    aname = "a"
382    bname = "b*has+a:tricky_name"
383    original = [
384        example(features=features({
385            aname: float_feature([1, 1]),
386            bname: bytes_feature([b"b0_str"]),
387        })), example(features=features({
388            aname: float_feature([-1, -1]),
389            bname: bytes_feature([b""]),
390        }))
391    ]
392
393    serialized = [m.SerializeToString() for m in original]
394
395    expected_output = {
396        aname:
397            np.array(  # pylint: disable=too-many-function-args
398                [[1, 1], [-1, -1]],
399                dtype=np.float32).reshape(2, 1, 2, 1),
400        bname:
401            np.array(  # pylint: disable=too-many-function-args
402                ["b0_str", ""],
403                dtype=bytes).reshape(2, 1, 1, 1, 1),
404    }
405
406    # No defaults, values required
407    self._test(
408        ops.convert_to_tensor(serialized), {
409            aname:
410                parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
411            bname:
412                parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
413        },
414        expected_values=expected_output,
415        create_iterator_twice=True)
416
  # This test is identical to the previous one except
  # for the creation of 'serialized'.
419  @combinations.generate(test_base.default_test_combinations())
420  def testSerializedContainingDenseWithConcat(self):
421    aname = "a"
422    bname = "b*has+a:tricky_name"
423    # TODO(lew): Feature appearing twice should be an error in future.
424    original = [
425        (example(features=features({
426            aname: float_feature([10, 10]),
427        })), example(features=features({
428            aname: float_feature([1, 1]),
429            bname: bytes_feature([b"b0_str"]),
430        }))),
431        (
432            example(features=features({
433                bname: bytes_feature([b"b100"]),
434            })),
435            example(features=features({
436                aname: float_feature([-1, -1]),
437                bname: bytes_feature([b"b1"]),
438            })),),
439    ]
440
441    serialized = [
442        m.SerializeToString() + n.SerializeToString() for (m, n) in original
443    ]
444
445    expected_output = {
446        aname:
447            np.array(  # pylint: disable=too-many-function-args
448                [[1, 1], [-1, -1]],
449                dtype=np.float32).reshape(2, 1, 2, 1),
450        bname:
451            np.array(  # pylint: disable=too-many-function-args
452                ["b0_str", "b1"],
453                dtype=bytes).reshape(2, 1, 1, 1, 1),
454    }
455
456    # No defaults, values required
457    self._test(
458        ops.convert_to_tensor(serialized), {
459            aname:
460                parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
461            bname:
462                parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
463        },
464        expected_values=expected_output,
465        create_iterator_twice=True)
466
467  @combinations.generate(test_base.default_test_combinations())
468  def testSerializedContainingDenseScalar(self):
469    original = [
470        example(features=features({
471            "a": float_feature([1]),
472        })), example(features=features({}))
473    ]
474
475    serialized = [m.SerializeToString() for m in original]
476
477    expected_output = {
478        "a":
479            np.array(
480                [[1], [-1]], dtype=np.float32)  # 2x1 (column vector)
481    }
482
483    self._test(
484        ops.convert_to_tensor(serialized), {
485            "a":
486                parsing_ops.FixedLenFeature(
487                    (1,), dtype=dtypes.float32, default_value=-1),
488        },
489        expected_values=expected_output,
490        create_iterator_twice=True)
491
492  @combinations.generate(test_base.default_test_combinations())
493  def testSerializedContainingDenseWithDefaults(self):
494    original = [
495        example(features=features({
496            "a": float_feature([1, 1]),
497        })),
498        example(features=features({
499            "b": bytes_feature([b"b1"]),
500        })),
501        example(features=features({
502            "b": feature()
503        })),
504    ]
505
506    serialized = [m.SerializeToString() for m in original]
507
508    expected_output = {
509        "a":
510            np.array(  # pylint: disable=too-many-function-args
511                [[1, 1], [3, -3], [3, -3]],
512                dtype=np.float32).reshape(3, 1, 2, 1),
513        "b":
514            np.array(  # pylint: disable=too-many-function-args
515                ["tmp_str", "b1", "tmp_str"],
516                dtype=bytes).reshape(3, 1, 1, 1, 1),
517    }
518
519    self._test(
520        ops.convert_to_tensor(serialized), {
521            "a":
522                parsing_ops.FixedLenFeature(
523                    (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
524            "b":
525                parsing_ops.FixedLenFeature(
526                    (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
527        },
528        expected_values=expected_output,
529        create_iterator_twice=True)
530
531  @combinations.generate(test_base.default_test_combinations())
532  def testSerializedSparseAndSparseFeatureAndDenseWithNoDefault(self):
533    expected_st_a = sparse_tensor.SparseTensorValue(  # indices, values, shape
534        np.empty((0, 2), dtype=np.int64),  # indices
535        np.empty((0,), dtype=np.int64),  # sp_a is DT_INT64
536        np.array([2, 0], dtype=np.int64))  # batch == 2, max_elems = 0
537    expected_sp = sparse_tensor.SparseTensorValue(  # indices, values, shape
538        np.array([[0, 0], [0, 3], [1, 7]], dtype=np.int64),
539        np.array(["a", "b", "c"], dtype="|S"),
540        np.array([2, 13], dtype=np.int64))  # batch == 4, max_elems = 13
541
542    original = [
543        example(features=features({
544            "c": float_feature([3, 4]),
545            "val": bytes_feature([b"a", b"b"]),
546            "idx": int64_feature([0, 3])
547        })), example(features=features({
548            "c": float_feature([1, 2]),
549            "val": bytes_feature([b"c"]),
550            "idx": int64_feature([7])
551        }))
552    ]
553
554    serialized = [m.SerializeToString() for m in original]
555
556    a_default = [1, 2, 3]
557    b_default = np.random.rand(3, 3).astype(bytes)
558    expected_output = {
559        "st_a": expected_st_a,
560        "sp": expected_sp,
561        "a": np.array(2 * [[a_default]]),
562        "b": np.array(2 * [b_default]),
563        "c": np.array(
564            [[3, 4], [1, 2]], dtype=np.float32),
565    }
566
567    self._test(
568        ops.convert_to_tensor(serialized),
569        {
570            "st_a":
571                parsing_ops.VarLenFeature(dtypes.int64),
572            "sp":
573                parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
574            "a":
575                parsing_ops.FixedLenFeature(
576                    (1, 3), dtypes.int64, default_value=a_default),
577            "b":
578                parsing_ops.FixedLenFeature(
579                    (3, 3), dtypes.string, default_value=b_default),
580            # Feature "c" must be provided, since it has no default_value.
581            "c":
582                parsing_ops.FixedLenFeature((2,), dtypes.float32),
583        },
584        expected_values=expected_output,
585        create_iterator_twice=True)
586
587  @combinations.generate(test_base.default_test_combinations())
588  def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
589    expected_idx = sparse_tensor.SparseTensorValue(  # indices, values, shape
590        np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
591        np.array([0, 3, 7, 1]),
592        np.array([2, 2], dtype=np.int64))  # batch == 4, max_elems = 2
593
594    expected_sp = sparse_tensor.SparseTensorValue(  # indices, values, shape
595        np.array([[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64),
596        np.array(["a", "b", "d", "c"], dtype="|S"),
597        np.array([2, 13], dtype=np.int64))  # batch == 4, max_elems = 13
598
599    original = [
600        example(features=features({
601            "val": bytes_feature([b"a", b"b"]),
602            "idx": int64_feature([0, 3])
603        })), example(features=features({
604            "val": bytes_feature([b"c", b"d"]),
605            "idx": int64_feature([7, 1])
606        }))
607    ]
608
609    serialized = [m.SerializeToString() for m in original]
610
611    expected_output = {
612        "idx": expected_idx,
613        "sp": expected_sp,
614    }
615
616    self._test(
617        ops.convert_to_tensor(serialized), {
618            "idx":
619                parsing_ops.VarLenFeature(dtypes.int64),
620            "sp":
621                parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
622        },
623        expected_values=expected_output,
624        create_iterator_twice=True)
625
626  @combinations.generate(
627      combinations.times(test_base.default_test_combinations(),
628                         combinations.combine(batch_size=[1, 10, 20, 100, 256]))
629  )
630  def testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
631    np.random.seed(3456)
632    # During parsing, data read from the serialized proto is stored in buffers.
633    # For small batch sizes, a buffer will contain one minibatch entry.
634    # For larger batch sizes, a buffer may contain several minibatch
635    # entries.  This test identified a bug where the code that copied
636    # data out of the buffers and into the output tensors assumed each
637    # buffer only contained one minibatch entry.  The bug has since been fixed.
638    truth_int = [i for i in range(batch_size)]
639    truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
640                 for i in range(batch_size)]
641
642    expected_str = copy.deepcopy(truth_str)
643
644    # Delete some intermediate entries
645    for i in range(batch_size):
646      col = 1
647      if np.random.rand() < 0.25:
648        # w.p. 25%, drop out the second entry
649        expected_str[i][col] = b"default"
650        col -= 1
651        truth_str[i].pop()
652      if np.random.rand() < 0.25:
653        # w.p. 25%, drop out the second entry (possibly again)
654        expected_str[i][col] = b"default"
655        truth_str[i].pop()
656
657    expected_output = {
658        # Batch size batch_size, 1 time step.
659        "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
660        # Batch size batch_size, 2 time steps.
661        "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
662    }
663
664    original = [
665        example(features=features(
666            {"a": int64_feature([truth_int[i]]),
667             "b": bytes_feature(truth_str[i])}))
668        for i in range(batch_size)
669    ]
670
671    serialized = [m.SerializeToString() for m in original]
672
673    self._test(
674        ops.convert_to_tensor(serialized, dtype=dtypes.string), {
675            "a":
676                parsing_ops.FixedLenSequenceFeature(
677                    shape=(),
678                    dtype=dtypes.int64,
679                    allow_missing=True,
680                    default_value=-1),
681            "b":
682                parsing_ops.FixedLenSequenceFeature(
683                    shape=[],
684                    dtype=dtypes.string,
685                    allow_missing=True,
686                    default_value="default"),
687        },
688        expected_values=expected_output,
689        create_iterator_twice=True)
690
691  @combinations.generate(test_base.default_test_combinations())
692  def testSerializedShapeMismatch(self):
693    aname = "a"
694    bname = "b"
695    cname = "c"
696    original = [
697        example(features=features({
698            cname: int64_feature([2]),
699        })),
700        example(features=features({
701            aname: float_feature([1, 1]),
702            bname: bytes_feature([b"b0_str", b"b1_str"]),
703        })),
704        example(features=features({
705            aname: float_feature([-1, -1, 2, 2]),
706            bname: bytes_feature([b"b1"]),
707        })),
708        example(features=features({
709            aname: float_feature([]),
710            cname: int64_feature([3]),
711        })),
712    ]
713
714    serialized = [m.SerializeToString() for m in original]
715    if context.executing_eagerly():
716      self._test(
717          ops.convert_to_tensor(serialized), {
718              aname:
719                  parsing_ops.FixedLenSequenceFeature((2, 1),
720                                                      dtype=dtypes.float32,
721                                                      allow_missing=True,
722                                                      default_value=[]),
723              bname:
724                  parsing_ops.FixedLenSequenceFeature(
725                      (2, 1, 1), dtype=dtypes.string, allow_missing=True),
726          },
727          expected_err=(errors_impl.InvalidArgumentError,
728                        "Input to reshape is a tensor with 0 values"))
729    else:
730      self._test(
731          ops.convert_to_tensor(serialized), {
732              aname:
733                  parsing_ops.FixedLenSequenceFeature((2, 1),
734                                                      dtype=dtypes.float32,
735                                                      allow_missing=True,
736                                                      default_value=[]),
737              bname:
738                  parsing_ops.FixedLenSequenceFeature(
739                      (2, 1, 1), dtype=dtypes.string, allow_missing=True),
740          },
741          expected_err=(ValueError,
742                        "Cannot reshape a tensor with 0 elements to shape"))
743
744  @combinations.generate(test_base.graph_only_combinations())
745  def testSerializedContainingVarLenDense(self):
746    aname = "a"
747    bname = "b"
748    cname = "c"
749    dname = "d"
750    original = [
751        example(features=features({
752            cname: int64_feature([2]),
753        })),
754        example(
755            features=features({
756                aname: float_feature([1, 1]),
757                bname: bytes_feature([b"b0_str", b"b1_str"]),
758            })),
759        example(
760            features=features({
761                aname: float_feature([-1, -1, 2, 2]),
762                bname: bytes_feature([b"b1"]),
763            })),
764        example(
765            features=features({
766                aname: float_feature([]),
767                cname: int64_feature([3]),
768            })),
769    ]
770
771    serialized = [m.SerializeToString() for m in original]
772
773    expected_output = {
774        aname:
775            np.array(  # pylint: disable=too-many-function-args
776                [
777                    [0, 0, 0, 0],
778                    [1, 1, 0, 0],
779                    [-1, -1, 2, 2],
780                    [0, 0, 0, 0],
781                ],
782                dtype=np.float32).reshape(4, 2, 2, 1),
783        bname:
784            np.array(  # pylint: disable=too-many-function-args
785                [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
786                dtype=bytes).reshape(4, 2, 1, 1, 1),
787        cname:
788            np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
789        dname:
790            np.empty(shape=(4, 0), dtype=bytes),
791    }
792
793    self._test(
794        ops.convert_to_tensor(serialized), {
795            aname:
796                parsing_ops.FixedLenSequenceFeature(
797                    (2, 1), dtype=dtypes.float32, allow_missing=True),
798            bname:
799                parsing_ops.FixedLenSequenceFeature(
800                    (1, 1, 1), dtype=dtypes.string, allow_missing=True),
801            cname:
802                parsing_ops.FixedLenSequenceFeature(
803                    shape=[], dtype=dtypes.int64, allow_missing=True),
804            dname:
805                parsing_ops.FixedLenSequenceFeature(
806                    shape=[], dtype=dtypes.string, allow_missing=True),
807        },
808        expected_values=expected_output,
809        create_iterator_twice=True)
810
811    # Test with padding values.
812    expected_output_custom_padding = dict(expected_output)
813    expected_output_custom_padding[aname] = np.array(  # pylint: disable=too-many-function-args
814        [
815            [-2, -2, -2, -2],
816            [1, 1, -2, -2],
817            [-1, -1, 2, 2],
818            [-2, -2, -2, -2],
819        ],
820        dtype=np.float32).reshape(4, 2, 2, 1)
821
822    self._test(
823        ops.convert_to_tensor(serialized), {
824            aname:
825                parsing_ops.FixedLenSequenceFeature(
826                    (2, 1),
827                    dtype=dtypes.float32,
828                    allow_missing=True,
829                    default_value=-2.0),
830            bname:
831                parsing_ops.FixedLenSequenceFeature(
832                    (1, 1, 1), dtype=dtypes.string, allow_missing=True),
833            cname:
834                parsing_ops.FixedLenSequenceFeature(
835                    shape=[], dtype=dtypes.int64, allow_missing=True),
836            dname:
837                parsing_ops.FixedLenSequenceFeature(
838                    shape=[], dtype=dtypes.string, allow_missing=True),
839        }, expected_output_custom_padding)
840
841    # Change number of required values so the inputs are not a
842    # multiple of this size.
843    self._test(
844        ops.convert_to_tensor(serialized), {
845            aname:
846                parsing_ops.FixedLenSequenceFeature(
847                    (2, 1), dtype=dtypes.float32, allow_missing=True),
848            bname:
849                parsing_ops.FixedLenSequenceFeature(
850                    (2, 1, 1), dtype=dtypes.string, allow_missing=True),
851        },
852        expected_err=(
853            errors_impl.OpError, "Key: b, Index: 2.  "
854            "Number of bytes values is not a multiple of stride length."))
855
856    self._test(
857        ops.convert_to_tensor(serialized), {
858            aname:
859                parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
860            bname:
861                parsing_ops.FixedLenSequenceFeature(
862                    (2, 1, 1), dtype=dtypes.string, allow_missing=True),
863        },
864        expected_err=(ValueError,
865                      "First dimension of shape for feature a unknown. "
866                      "Consider using FixedLenSequenceFeature."))
867
868    self._test(
869        ops.convert_to_tensor(serialized), {
870            cname:
871                parsing_ops.FixedLenFeature(
872                    (1, None), dtype=dtypes.int64, default_value=[[1]]),
873        },
874        expected_err=(ValueError,
875                      "All dimensions of shape for feature c need to be known "
876                      r"but received \(1, None\)."))
877
878    self._test(
879        ops.convert_to_tensor(serialized), {
880            aname:
881                parsing_ops.FixedLenSequenceFeature(
882                    (2, 1), dtype=dtypes.float32, allow_missing=True),
883            bname:
884                parsing_ops.FixedLenSequenceFeature(
885                    (1, 1, 1), dtype=dtypes.string, allow_missing=True),
886            cname:
887                parsing_ops.FixedLenSequenceFeature(
888                    shape=[], dtype=dtypes.int64, allow_missing=False),
889            dname:
890                parsing_ops.FixedLenSequenceFeature(
891                    shape=[], dtype=dtypes.string, allow_missing=True),
892        },
893        expected_err=(ValueError,
894                      "Unsupported: FixedLenSequenceFeature requires "
895                      "allow_missing to be True."))
896
897  @combinations.generate(test_base.default_test_combinations())
898  def testSerializedContainingRaggedFeatureWithNoPartitions(self):
899    original = [
900        example(
901            features=features({
902                "rt_c": float_feature([3, 4, 5, 6, 7, 8]),
903                "rt_f_values": float_feature([0, 1, 2, 3, 4]),
904            })),
905        example(
906            features=features({
907                "rt_c": float_feature([]),  # empty float list
908            })),
909        example(
910            features=features({
911                "rt_d": feature(),  # feature with nothing in it
912            })),
913        example(
914            features=features({
915                "rt_c": float_feature([1, 2, -1]),
916                "rt_d": bytes_feature([b"hi"]),
917                "rt_f_values": float_feature([0, 1, 2]),
918            }))
919    ]
920
921    serialized = [m.SerializeToString() for m in original]
922
923    expected_rt_c = ragged_factory_ops.constant_value(
924        [[3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [], [], [1.0, 2.0, -1.0]],
925        row_splits_dtype=dtypes.int32)
926    expected_rt_d = ragged_factory_ops.constant_value(
927        [[], [], [], [b"hi"]], row_splits_dtype=dtypes.int64)
928    expected_rt_f = ragged_factory_ops.constant_value(
929        [[0.0, 1.0, 2.0, 3.0, 4.0], [], [], [0.0, 1.0, 2.0]],
930        row_splits_dtype=dtypes.int32)
931
932    expected_output = {
933        "rt_c": expected_rt_c,
934        "rt_d": expected_rt_d,
935        "rt_f": expected_rt_f,
936    }
937
938    self._test(
939        ops.convert_to_tensor(serialized), {
940            "rt_c":
941                parsing_ops.RaggedFeature(dtypes.float32),
942            "rt_d":
943                parsing_ops.RaggedFeature(
944                    dtypes.string, row_splits_dtype=dtypes.int64),
945            "rt_f":
946                parsing_ops.RaggedFeature(
947                    dtypes.float32, value_key="rt_f_values"),
948        },
949        expected_values=expected_output,
950        create_iterator_twice=True)
951
952  @combinations.generate(test_base.default_test_combinations())
953  def testSerializedContainingRaggedFeatureWithOnePartition(self):
954    original = [
955        example(
956            features=features({
957                # rt = [[3], [4, 5, 6]]
958                "rt_values": float_feature([3, 4, 5, 6]),
959                "rt_splits": int64_feature([0, 1, 4]),
960                "rt_lengths": int64_feature([1, 3]),
961                "rt_starts": int64_feature([0, 1]),
962                "rt_limits": int64_feature([1, 4]),
963                "rt_rowids": int64_feature([0, 1, 1, 1]),
964            })),
965        example(
966            features=features({
967                # rt = []
968                "rt_values": float_feature([]),
969                "rt_splits": int64_feature([0]),
970                "rt_lengths": int64_feature([]),
971                "rt_starts": int64_feature([]),
972                "rt_limits": int64_feature([]),
973                "rt_rowids": int64_feature([]),
974            })),
975        example(
976            features=features({
977                # rt = []
978                "rt_values": feature(),  # feature with nothing in it
979                "rt_splits": int64_feature([0]),
980                "rt_lengths": feature(),
981                "rt_starts": feature(),
982                "rt_limits": feature(),
983                "rt_rowids": feature(),
984            })),
985        example(
986            features=features({
987                # rt = [[1.0, 2.0, -1.0], [], [8.0, 9.0], [5.0]]
988                "rt_values": float_feature([1, 2, -1, 8, 9, 5]),
989                "rt_splits": int64_feature([0, 3, 3, 5, 6]),
990                "rt_lengths": int64_feature([3, 0, 2, 1]),
991                "rt_starts": int64_feature([0, 3, 3, 5]),
992                "rt_limits": int64_feature([3, 3, 5, 6]),
993                "rt_rowids": int64_feature([0, 0, 0, 2, 2, 3]),
994            }))
995    ]
996    serialized = [m.SerializeToString() for m in original]
997
998    test_features = {
999        "rt1":
1000            parsing_ops.RaggedFeature(
1001                value_key="rt_values",
1002                partitions=[parsing_ops.RaggedFeature.RowSplits("rt_splits")],
1003                dtype=dtypes.float32),
1004        "rt2":
1005            parsing_ops.RaggedFeature(
1006                value_key="rt_values",
1007                partitions=[parsing_ops.RaggedFeature.RowLengths("rt_lengths")],
1008                dtype=dtypes.float32),
1009        "rt3":
1010            parsing_ops.RaggedFeature(
1011                value_key="rt_values",
1012                partitions=[parsing_ops.RaggedFeature.RowStarts("rt_starts")],
1013                dtype=dtypes.float32),
1014        "rt4":
1015            parsing_ops.RaggedFeature(
1016                value_key="rt_values",
1017                partitions=[parsing_ops.RaggedFeature.RowLimits("rt_limits")],
1018                dtype=dtypes.float32),
1019        "rt5":
1020            parsing_ops.RaggedFeature(
1021                value_key="rt_values",
1022                partitions=[parsing_ops.RaggedFeature.ValueRowIds("rt_rowids")],
1023                dtype=dtypes.float32),
1024        "uniform1":
1025            parsing_ops.RaggedFeature(
1026                value_key="rt_values",
1027                partitions=[parsing_ops.RaggedFeature.UniformRowLength(2)],
1028                dtype=dtypes.float32),
1029        "uniform2":
1030            parsing_ops.RaggedFeature(
1031                value_key="rt_values",
1032                partitions=[
1033                    parsing_ops.RaggedFeature.UniformRowLength(2),
1034                    parsing_ops.RaggedFeature.RowSplits("rt_splits")
1035                ],
1036                dtype=dtypes.float32),
1037    }
1038
1039    expected_rt = ragged_factory_ops.constant(
1040        [[[3], [4, 5, 6]], [], [], [[1, 2, -1], [], [8, 9], [5]]],
1041        dtype=dtypes.float32,
1042        row_splits_dtype=dtypes.int32)
1043
1044    expected_uniform1 = ragged_factory_ops.constant(
1045        [[[3, 4], [5, 6]], [], [], [[1, 2], [-1, 8], [9, 5]]],
1046        ragged_rank=1,
1047        dtype=dtypes.float32,
1048        row_splits_dtype=dtypes.int32)
1049
1050    expected_uniform2 = ragged_factory_ops.constant(
1051        [[[[3], [4, 5, 6]]], [], [], [[[1, 2, -1], []], [[8, 9], [5]]]],
1052        dtype=dtypes.float32,
1053        row_splits_dtype=dtypes.int32)
1054
1055    expected_output = {
1056        "rt1": expected_rt,
1057        "rt2": expected_rt,
1058        "rt3": expected_rt,
1059        "rt4": expected_rt,
1060        "rt5": expected_rt,
1061        "uniform1": expected_uniform1,
1062        "uniform2": expected_uniform2,
1063    }
1064
1065    self._test(
1066        ops.convert_to_tensor(serialized),
1067        test_features,
1068        expected_values=expected_output,
1069        create_iterator_twice=True)
1070
1071  @combinations.generate(test_base.default_test_combinations())
1072  def testSerializedContainingRaggedFeatureWithMultiplePartitions(self):
1073    original = [
1074        # rt shape: [(batch), 2, None, None]
1075        example(
1076            features=features({
1077                # rt = [[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]]
1078                "rt_values": float_feature([1, 2, 3, 4, 5, 6, 7]),
1079                "lengths_axis2": int64_feature([1, 2, 0, 1]),
1080                "lengths_axis3": int64_feature([1, 2, 1, 3]),
1081                "splits_axis3": int64_feature([0, 1, 3, 4, 7]),
1082            })),
1083        example(
1084            features=features({
1085                # rt = [[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]
1086                "rt_values": float_feature([1, 2, 3, 4, 5, 6, 7, 8]),
1087                "lengths_axis2": int64_feature([2, 3]),
1088                "lengths_axis3": int64_feature([3, 1, 1, 1, 2]),
1089                "splits_axis3": int64_feature([0, 3, 4, 5, 6, 8]),
1090            }))
1091    ]
1092    serialized = [m.SerializeToString() for m in original]
1093
1094    test_features = {
1095        "rt1":
1096            parsing_ops.RaggedFeature(
1097                value_key="rt_values",
1098                partitions=[
1099                    parsing_ops.RaggedFeature.UniformRowLength(2),
1100                    parsing_ops.RaggedFeature.RowLengths("lengths_axis2"),
1101                    parsing_ops.RaggedFeature.RowSplits("splits_axis3"),
1102                ],
1103                dtype=dtypes.float32,
1104                row_splits_dtype=dtypes.int64,
1105            ),
1106    }
1107
1108    expected_rt = ragged_factory_ops.constant(
1109        [[[[[1]], [[2, 3], [4]]], [[], [[5, 6, 7]]]],
1110         [[[[1, 2, 3], [4]], [[5], [6], [7, 8]]]]],
1111        dtype=dtypes.float32,
1112        row_splits_dtype=dtypes.int64)
1113
1114    expected_output = {
1115        "rt1": expected_rt,
1116    }
1117
1118    self._test(
1119        ops.convert_to_tensor(serialized),
1120        test_features,
1121        expected_values=expected_output,
1122        create_iterator_twice=True)
1123
1124  @combinations.generate(
1125      combinations.times(
1126          test_base.default_test_combinations(),
1127          combinations.combine(
1128              local_determinism=[None, True, False],
1129              global_determinism=[True, False])))
1130  def testDeterminism(self, local_determinism, global_determinism):
1131    num_elements = 1000
1132    batches = []
1133    for i in range(num_elements):
1134      example_i = example(features=features({
1135          "a": int64_feature([i]),
1136      }))
1137      batches.append([example_i.SerializeToString()])
1138
1139    test_features = {"a": parsing_ops.FixedLenFeature((), dtype=dtypes.int64)}
1140    dataset = dataset_ops.Dataset.from_tensor_slices(batches)
1141    dataset = dataset.apply(
1142        contrib_parsing_ops.parse_example_dataset(
1143            test_features,
1144            num_parallel_calls=10,
1145            deterministic=local_determinism))
1146
1147    opts = options_lib.Options()
1148    opts.deterministic = global_determinism
1149    dataset = dataset.with_options(opts)
1150
1151    expected = list(range(num_elements))
1152    actual = [elem["a"][0] for elem in self.getDatasetOutput(dataset)]
1153
1154    require_order = local_determinism or (local_determinism is None and
1155                                          global_determinism)
1156    if require_order:
1157      self.assertAllEqual(expected, actual)
1158    else:
1159      self.assertCountEqual(expected, actual)
1160
1161
class ParseExampleDatasetCheckpointTest(tf_record_test_base.FeaturesTestBase,
                                        checkpoint_test_base.CheckpointTestBase,
                                        parameterized.TestCase):
  """Checkpoint test for the pipeline built by `make_batch_feature`."""

  def _parse_example_dataset(self, num_repeat, batch_size):
    """Builds the pipeline under test via `self.make_batch_feature`.

    Args:
      num_repeat: Forwarded as `num_epochs`.
      batch_size: Forwarded as `batch_size`.

    Returns:
      Whatever `self.make_batch_feature` returns for these settings.
    """
    pipeline_kwargs = dict(
        filenames=self._filenames,
        num_epochs=num_repeat,
        batch_size=batch_size,
        reader_num_threads=5,
        parser_num_threads=10)
    return self.make_batch_feature(**pipeline_kwargs)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    """Runs `verify_fn` on the pipeline with its expected output count."""
    num_repeat = 5
    batch_size = 2
    # Records across all files and repeats, grouped into batches.
    total_records = self._num_records * self._num_files * num_repeat
    num_outputs = total_records // batch_size

    def build_dataset():
      return self._parse_example_dataset(
          num_repeat=num_repeat, batch_size=batch_size)

    verify_fn(self, build_dataset, num_outputs)
1185
1186
# Hand control to `test.main()` when this file is executed as a script.
if __name__ == "__main__":
  test.main()
1189