xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/structured/structured_tensor_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for StructuredTensor."""
16
17import textwrap
18
19from absl.testing import parameterized
20import numpy as np
21
22from tensorflow.python.eager import context
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import extension_type
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import sparse_tensor
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.framework import test_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops.ragged import ragged_factory_ops
33from tensorflow.python.ops.ragged import ragged_tensor
34from tensorflow.python.ops.ragged import row_partition
35from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape
36
37# TODO(b/173144447): remove when structured_array_ops is included in init.
38from tensorflow.python.ops.structured import structured_array_ops  # pylint: disable=unused-import
39
40from tensorflow.python.ops.structured import structured_tensor
41from tensorflow.python.ops.structured import structured_tensor_dynamic
42from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
43from tensorflow.python.platform import googletest
44from tensorflow.python.util import dispatch
45
46
class _PrivateSpecialType(extension_type.ExtensionType):
  """Test-only ExtensionType that wraps a single RaggedTensor.

  `shape_v2_special` registers a `shape_v2` dispatch for this type, which is
  what lets `testWithPrivateSpecialType` use it as a StructuredTensor field.
  """
  ragged: ragged_tensor.RaggedTensor  # The wrapped value whose shape is reported.
49
50
@dispatch.dispatch_for_types(array_ops.shape_v2, _PrivateSpecialType)
def shape_v2_special(input: _PrivateSpecialType, out_type=dtypes.int32,  # pylint: disable=redefined-builtin
                     name=None):
  """`shape_v2` override for `_PrivateSpecialType`.

  Delegates to `array_ops.shape_v2` on the wrapped `ragged` field, i.e. it
  returns a DynamicRaggedShape containing the shape of the input. The `name`
  argument is accepted for signature compatibility and ignored.
  """
  del name  # Unused.
  wrapped = input.ragged
  return array_ops.shape_v2(wrapped, out_type)
57
58
class _PrivateBrokenType(extension_type.ExtensionType):
  """Test-only ExtensionType whose `shape_v2` dispatch returns a non-shape.

  `shape_v2_broken` is registered for this type; `testWithPrivateBrokenType`
  uses it to check that StructuredTensor rejects such a field.
  """
  ragged: ragged_tensor.RaggedTensor  # Wrapped value; ignored by shape_v2_broken.
61
62
@dispatch.dispatch_for_types(array_ops.shape_v2, _PrivateBrokenType)
def shape_v2_broken(input: _PrivateBrokenType, out_type=dtypes.int32,  # pylint: disable=redefined-builtin
                    name=None):
  """Deliberately broken `shape_v2` override for `_PrivateBrokenType`.

  Instead of a shape, returns a dict of strings. This lets
  `testWithPrivateBrokenType` check that StructuredTensor reports an error
  ("Error in shape of r") for a field whose shape cannot be interpreted.
  """
  del input, out_type, name  # Every argument is intentionally ignored.
  not_a_shape = dict(
      foo="This is not a shape",
      bar="But if I put a string here, it becomes a vector")
  return not_a_shape
74
75
76# pylint: disable=g-long-lambda
77@test_util.run_all_in_graph_and_eager_modes
78class StructuredTensorTest(test_util.TensorFlowTestCase,
79                           parameterized.TestCase):
80
81  def assertAllEqual(self, a, b, msg=None):
82    if not (isinstance(a, structured_tensor.StructuredTensor) or
83            isinstance(b, structured_tensor.StructuredTensor)):
84      return super(StructuredTensorTest, self).assertAllEqual(a, b, msg)
85    if not isinstance(a, structured_tensor.StructuredTensor):
86      a = structured_tensor.StructuredTensor.from_pyval(a)
87      self._assertStructuredEqual(a, b, msg, False)
88    elif not isinstance(b, structured_tensor.StructuredTensor):
89      b = structured_tensor.StructuredTensor.from_pyval(b)
90      self._assertStructuredEqual(a, b, msg, False)
91    else:
92      self._assertStructuredEqual(a, b, msg, True)
93
94  def _assertStructuredEqual(self, a, b, msg, check_shape):
95    if check_shape:
96      self.assertEqual(repr(a.shape), repr(b.shape))
97    self.assertEqual(set(a.field_names()), set(b.field_names()))
98    for field in a.field_names():
99      a_value = a.field_value(field)
100      b_value = b.field_value(field)
101      self.assertIs(type(a_value), type(b_value))
102      if isinstance(a_value, structured_tensor.StructuredTensor):
103        self._assertStructuredEqual(a_value, b_value, msg, check_shape)
104      else:
105        self.assertAllEqual(a_value, b_value, msg)
106
  @parameterized.named_parameters([
      # Scalar (rank=0) StructuredTensors.
      {
          "testcase_name": "Rank0_WithTensorFields",
          "rank": 0,
          "fields": {"Foo": 5, "Bar": [1, 2, 3]},
          "expected_shape": []
      },
      {
          "testcase_name": "Rank0_WithRaggedFields",
          "fields": {
              # note: fields have varying rank & ragged_rank.
              "p": ragged_factory_ops.constant_value([[1, 2], [3]]),
              "q": ragged_factory_ops.constant_value([[[4]], [], [[5, 6]]]),
              "r": ragged_factory_ops.constant_value([[[4]], [], [[5]]],
                                                     ragged_rank=1),
              "s": ragged_factory_ops.constant_value([[[4]], [], [[5]]],
                                                     ragged_rank=2),
          },
          "rank": 0,
          "expected_shape": [],
      },
      {
          "testcase_name": "Rank0_WithStructuredFields",
          "fields": lambda: {
              "foo": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}),
              "bar": StructuredTensor.from_pyval(
                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
              },
          "rank": 0,
          "expected_shape": [],
      },
      {
          "testcase_name": "Rank0_WithMixedFields",
          "fields": lambda: {
              # TODO(martinz): should handle this, but can't.
              "f1": 5,
              "f2": [1, 2, 3],
              "f3": ragged_factory_ops.constant_value([[1, 2], [3]]),
              "f4": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}),
          },
          "rank": 0,
          "expected_shape": [],
      },
      # Vector (rank=1) StructuredTensors.
      # NOTE(review): despite its name, this case has no `nrows` entry (the
      # test only passes `fields` and `rank`) and is identical to
      # Rank1_WithTensorFields below; the name appears to be carried over from
      # testFromFields.
      {
          "testcase_name": "Rank1_WithExplicitNrows",
          "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]},
          "rank": 1,
          "expected_shape": [2],
      },
      {
          "testcase_name": "Rank1_WithTensorFields",
          "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]},
          "rank": 1,
          "expected_shape": [2],

      },
      {
          "testcase_name": "Rank1_WithRaggedFields",
          "fields": {
              # note: fields have varying rank & ragged_rank.
              "p": ragged_factory_ops.constant_value([[1, 2], [3]]),
              "q": ragged_factory_ops.constant_value([[[4]], [[5, 6], [7]]]),
              "r": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]]),
              "s": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
                                                     ragged_rank=1),
              "t": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
                                                     ragged_rank=2),
          },
          "rank": 1,
          "expected_shape": [2],
      },
      {
          "testcase_name": "Rank1_WithStructuredFields",
          "fields": lambda: {
              "foo": StructuredTensor.from_pyval(
                  [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]),
              "bar": StructuredTensor.from_pyval(
                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
          },
          "rank": 1,
          "expected_shape": [2],
      },
      {
          "testcase_name": "Rank1_WithMixedFields",
          "fields": lambda: {
              "x": [1, 2],
              "y": [[1, 2], [3, 4]],
              "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
              "s": StructuredTensor.from_pyval(
                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
          },
          "rank": 1,
          "expected_shape": [2],
      },
      {
          "testcase_name": "Rank1_WithNoElements",
          "fields": lambda: {
              "x": [],
              "y": np.zeros([0, 8]),
              "r": ragged_factory_ops.constant([], ragged_rank=1),
              "s": StructuredTensor.from_pyval([]),
          },
          "rank": 1,
          "expected_shape": [0],  # Note: could also be [None] (?)
      },
      {
          "testcase_name": "Rank1_InferDimSize",
          "fields": lambda: {
              "x": [1, 2],
              "y": [[1, 2], [3, 4]],
              "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
              "p": ragged_factory_ops.constant_value([[4], [5, 6, 7]]),
              "foo": StructuredTensor.from_pyval(
                  [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]),
              "bar": StructuredTensor.from_pyval(
                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
          },
          "rank": 1,
          "expected_shape": [2],  # inferred from field values.
      },
      # Matrix (rank=2) StructuredTensors.
      {
          "testcase_name": "Rank2_WithTensorFields",
          "fields": {
              "x": [[1, 2, 3], [4, 5, 6]],
              "y": np.ones([2, 3, 8])
          },
          "rank": 2,
          "expected_shape": [2, 3],  # inferred from field values.
      },
      {
          "testcase_name": "Rank2_WithRaggedFields",
          "fields": {
              # Note: fields must have identical row_splits.
              "a": ragged_factory_ops.constant_value([[1, 2], [3]]),
              "b": ragged_factory_ops.constant_value([[4, 5], [6]]),
              "c": ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
              "d": ragged_factory_ops.constant_value(
                  [[[[1, 2], [3]], [[4], [], [5]]], [[[6, 7, 8], []]]]),
          },
          "rank": 2,
          "expected_shape": [2, None],
      },
      {
          "testcase_name": "Rank2_WithStructuredFields",
          "fields": lambda: {
              # Note: fields must have identical row_splits.
              "a": StructuredTensor.from_pyval(
                  [[{"x": 1}], [{"x": 2}, {"x": 3}]]),
              "b": StructuredTensor.from_pyval(
                  [[[{"y": 1}]], [[], [{"y": 2}, {"y": 3}]]]),
          },
          "rank": 2,
          "expected_shape": [2, None],  # ragged shape = [[*], [*, *]]
      },
      {
          "testcase_name": "Rank2_WithMixedFields",
          "fields": lambda: {
              "a": [[1, 2], [3, 4]],
              "b": ragged_factory_ops.constant_value([[1, 2], [3, 4]]),
              "c": StructuredTensor.from_pyval(
                  [[[{"y": 1}], []], [[], [{"y": 2}, {"y": 3}]]]),
              "d": ragged_factory_ops.constant_value(
                  [[[1, 2], []], [[3], [4]]]),
          },
          "rank": 2,
          "expected_shape": [2, 2],
      },
      # Rank=4 StructuredTensors.
      {
          "testcase_name": "Rank4_WithMixedFields",
          "fields": lambda: {
              "a": np.ones([1, 2, 3, 1]),
              "b": np.ones([1, 2, 3, 1, 5]),
              "c": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1])),
              "d": ragged_factory_ops.constant(
                  np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1),
              "e": ragged_factory_ops.constant(
                  np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
              "f": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3])),
              "g": StructuredTensor.from_pyval(
                  [[[[{"x": j, "y": k}] for k in range(3)]
                    for j in range(2)]]),
              "h": StructuredTensor.from_pyval(
                  [[[[[{"x": j, "y": k, "z": z} for z in range(j)]]
                     for k in range(3)]
                    for j in range(2)]]),
          },
          "rank": 4,
          "expected_shape": [1, 2, 3, 1],  # inferred from field values.
      },
  ])  # pyformat: disable
  def testFromFieldsAndRank(self, fields, rank, expected_shape):
    """Checks the static shape produced by `from_fields_and_rank`.

    Args:
      fields: Dict of field values, or a zero-argument callable returning one
        (used when the fields include tensors).
      rank: The `rank` argument passed to `from_fields_and_rank`.
      expected_shape: The expected value of `struct.shape.as_list()`.
    """
    if callable(fields):
      fields = fields()  # deferred construction: fields may include tensors.

    struct = StructuredTensor.from_fields_and_rank(fields, rank)
    self.assertEqual(struct.shape.as_list(), expected_shape)
307
308  @parameterized.named_parameters([
309      {
310          "testcase_name": "NoFields",
311          "rank": 1,
312          "fields": {},
313          "msg": "Must provide at least one field"
314      },
315      {
316          "testcase_name": "IntegerRank",
317          "rank": 0.5,
318          "fields": {
319              "foo": [1]
320          },
321          "msg": "rank must be an integer"
322      },
323      {
324          "testcase_name": "NonNegativeRank",
325          "rank": -1,
326          "fields": {
327              "bar": [1, 2, 3]
328          },
329          "msg": "rank must be nonnegative"
330      },
331  ])
332  def testFromFieldsAndRankError(self, fields, rank, msg):
333    if callable(fields):
334      fields = fields()  # deferred construction: fields may include tensors.
335    with self.assertRaisesRegex(ValueError, msg):
336      StructuredTensor.from_fields_and_rank(fields, rank)
337
338  @parameterized.named_parameters([
339      # Scalar (rank=0) StructuredTensors.
340      {
341          "testcase_name": "Rank0_WithNoFields",
342          "shape": [],
343          "fields": {},
344      },
345      {
346          "testcase_name": "Rank0_WithTensorFields",
347          "shape": [],
348          "fields": {"Foo": 5, "Bar": [1, 2, 3]},
349      },
350      {
351          "testcase_name": "Rank0_WithRaggedFields",
352          "shape": [],
353          "fields": {
354              # note: fields have varying rank & ragged_rank.
355              "p": ragged_factory_ops.constant_value([[1, 2], [3]]),
356              "q": ragged_factory_ops.constant_value([[[4]], [], [[5, 6]]]),
357              "r": ragged_factory_ops.constant_value([[[4]], [], [[5]]],
358                                                     ragged_rank=1),
359              "s": ragged_factory_ops.constant_value([[[4]], [], [[5]]],
360                                                     ragged_rank=2),
361          },
362      },
363      {
364          "testcase_name": "Rank0_WithStructuredFields",
365          "shape": [],
366          "fields": lambda: {
367              "foo": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}),
368              "bar": StructuredTensor.from_pyval(
369                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
370              },
371      },
372      {
373          "testcase_name": "Rank0_WithMixedFields",
374          "shape": [],
375          "fields": lambda: {
376              "f1": 5,
377              "f2": [1, 2, 3],
378              "f3": ragged_factory_ops.constant_value([[1, 2], [3]]),
379              "f4": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}),
380          },
381      },
382      # Vector (rank=1) StructuredTensors.
383      {
384          "testcase_name": "Rank1_WithNoFields",
385          "shape": [2],
386          "fields": {},
387      },
388      {
389          "testcase_name": "Rank1_WithExplicitNrows",
390          "shape": [None],
391          "nrows": 2,
392          "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]},
393          "expected_shape": [2],
394      },
395      {
396          "testcase_name": "Rank1_WithTensorFields",
397          "shape": [2],
398          "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]},
399      },
400      {
401          "testcase_name": "Rank1_WithRaggedFields",
402          "shape": [2],
403          "fields": {
404              # note: fields have varying rank & ragged_rank.
405              "p": ragged_factory_ops.constant_value([[1, 2], [3]]),
406              "q": ragged_factory_ops.constant_value([[[4]], [[5, 6], [7]]]),
407              "r": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]]),
408              "s": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
409                                                     ragged_rank=1),
410              "t": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]],
411                                                     ragged_rank=2),
412          },
413      },
414      {
415          "testcase_name": "Rank1_WithStructuredFields",
416          "shape": [2],
417          "fields": lambda: {
418              "foo": StructuredTensor.from_pyval(
419                  [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]),
420              "bar": StructuredTensor.from_pyval(
421                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
422          },
423      },
424      {
425          "testcase_name": "Rank1_WithMixedFields",
426          "shape": [2],
427          "fields": lambda: {
428              "x": [1, 2],
429              "y": [[1, 2], [3, 4]],
430              "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
431              "s": StructuredTensor.from_pyval(
432                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
433          },
434      },
435      {
436          "testcase_name": "Rank1_WithNoElements",
437          "shape": [0],
438          "fields": lambda: {
439              "x": [],
440              "y": np.zeros([0, 8]),
441              "r": ragged_factory_ops.constant([], ragged_rank=1),
442              "s": StructuredTensor.from_pyval([]),
443          },
444      },
445      {
446          "testcase_name": "Rank1_InferDimSize",
447          "shape": [None],
448          "fields": lambda: {
449              "x": [1, 2],
450              "y": [[1, 2], [3, 4]],
451              "r": ragged_factory_ops.constant_value([[1, 2], [3]]),
452              "p": ragged_factory_ops.constant_value([[4], [5, 6, 7]]),
453              "foo": StructuredTensor.from_pyval(
454                  [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]),
455              "bar": StructuredTensor.from_pyval(
456                  [[{"x": 12}], [{"x": 13}, {"x": 14}]]),
457          },
458          "expected_shape": [2],  # inferred from field values.
459      },
460      # Matrix (rank=2) StructuredTensors.
461      {
462          "testcase_name": "Rank2_WithNoFields",
463          "shape": [2, 8],
464          "fields": {},
465      },
466      {
467          "testcase_name": "Rank2_WithNoFieldsAndExplicitRowPartitions",
468          "shape": [2, None],
469          "row_partitions":
470              lambda: [row_partition.RowPartition.from_row_lengths([3, 7])],
471          "fields": {},
472      },
473      {
474          "testcase_name": "Rank2_WithTensorFields",
475          "shape": [None, None],
476          "fields": {
477              "x": [[1, 2, 3], [4, 5, 6]],
478              "y": np.ones([2, 3, 8])
479          },
480          "expected_shape": [2, 3],  # inferred from field values.
481      },
482      {
483          "testcase_name": "Rank2_WithRaggedFields",
484          "shape": [2, None],  # ragged shape = [[*, *], [*]]
485          "fields": {
486              # Note: fields must have identical row_splits.
487              "a": ragged_factory_ops.constant_value([[1, 2], [3]]),
488              "b": ragged_factory_ops.constant_value([[4, 5], [6]]),
489              "c": ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
490              "d": ragged_factory_ops.constant_value(
491                  [[[[1, 2], [3]], [[4], [], [5]]], [[[6, 7, 8], []]]]),
492          },
493      },
494      {
495          "testcase_name": "Rank2_WithStructuredFields",
496          "shape": [2, None],  # ragged shape = [[*], [*, *]]
497          "fields": lambda: {
498              # Note: fields must have identical row_splits.
499              "a": StructuredTensor.from_pyval(
500                  [[{"x": 1}], [{"x": 2}, {"x": 3}]]),
501              "b": StructuredTensor.from_pyval(
502                  [[[{"y": 1}]], [[], [{"y": 2}, {"y": 3}]]]),
503          },
504      },
505      {
506          "testcase_name": "Rank2_WithMixedFields",
507          "shape": [2, None],
508          "fields": lambda: {
509              "a": [[1, 2], [3, 4]],
510              "b": ragged_factory_ops.constant_value([[1, 2], [3, 4]]),
511              "c": StructuredTensor.from_pyval(
512                  [[[{"y": 1}], []], [[], [{"y": 2}, {"y": 3}]]]),
513              "d": ragged_factory_ops.constant_value(
514                  [[[1, 2], []], [[3], [4]]]),
515          },
516          "expected_shape": [2, 2],
517      },
518      # Rank=4 StructuredTensors.
519      {
520          "testcase_name": "Rank4_WithNoFields",
521          "shape": [1, None, None, 3],
522          "fields": {},
523          "row_partitions": lambda: [
524              row_partition.RowPartition.from_row_lengths([3]),
525              row_partition.RowPartition.from_row_lengths([2, 0, 1]),
526              row_partition.RowPartition.from_uniform_row_length(3, nvals=9)
527          ]
528      },
529      {
530          "testcase_name": "Rank4_WithMixedFields",
531          "shape": [1, None, None, 1],
532          "fields": lambda: {
533              "a": np.ones([1, 2, 3, 1]),
534              "b": np.ones([1, 2, 3, 1, 5]),
535              "c": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1])),
536              "d": ragged_factory_ops.constant(
537                  np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1),
538              "e": ragged_factory_ops.constant(
539                  np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
540              "f": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3])),
541              "g": StructuredTensor.from_pyval(
542                  [[[[{"x": j, "y": k}] for k in range(3)]
543                    for j in range(2)]]),
544              "h": StructuredTensor.from_pyval(
545                  [[[[[{"x": j, "y": k, "z": z} for z in range(j)]]
546                     for k in range(3)]
547                    for j in range(2)]]),
548          },
549          "expected_shape": [1, 2, 3, 1],  # inferred from field values.
550      },
551  ])  # pyformat: disable
552  def testFromFields(self,
553                     shape,
554                     fields,
555                     expected_shape=None,
556                     nrows=None,
557                     row_partitions=None):
558    if callable(fields):
559      fields = fields()  # deferred construction: fields may include tensors.
560    if callable(nrows):
561      nrows = nrows()  # deferred construction.
562    if callable(row_partitions):
563      row_partitions = row_partitions()  # deferred construction.
564    for validate in (True, False):
565      struct = StructuredTensor.from_fields(
566          fields,
567          shape,
568          nrows=nrows,
569          row_partitions=row_partitions,
570          validate=validate)
571      if expected_shape is None:
572        expected_shape = shape
573      self.assertEqual(struct.shape.as_list(), expected_shape)
574      self.assertLen(expected_shape, struct.rank)
575      self.assertCountEqual(struct.field_names(), tuple(fields.keys()))
576      for field, value in fields.items():
577        self.assertIsInstance(
578            struct.field_value(field),
579            (ops.Tensor, structured_tensor.StructuredTensor,
580             ragged_tensor.RaggedTensor))
581        self.assertAllEqual(struct.field_value(field), value)
582
583  @parameterized.parameters([
584      dict(fields={}, shape=object(), err=TypeError),
585      dict(
586          fields=object(),
587          shape=[],
588          err=TypeError,
589          msg="fields must be a dictionary"),
590      dict(
591          fields={1: 2}, shape=[], err=TypeError,
592          msg="Unexpected type for key"),
593      dict(
594          fields={"x": object()},
595          shape=[],
596          err=(TypeError, ValueError),
597          msg="Error with shape of x|Unexpected type for value"),
598      dict(
599          fields={},
600          shape=None,
601          err=ValueError,
602          msg="StructuredTensor's shape must have known rank"),
603      dict(
604          fields={"f": 5},
605          shape=[5],
606          err=ValueError,
607          msg=r"Field f has shape \(\), which is incompatible with the shape "
608          r"that was specified or inferred from other fields: \(5,\)|Shapes"),
609      dict(
610          fields=dict(x=[1], y=[]),
611          shape=[None],
612          err=ValueError,
613          msg=r"Error in shape of y"),
614      dict(
615          fields={"": 5},
616          shape=[],
617          err=ValueError,
618          msg="Field name '' is not currently allowed."),
619      dict(
620          fields={"_": 5},
621          shape=[],
622          err=ValueError,
623          msg="Field name '_' is not currently allowed."),
624      dict(
625          fields={
626              "r1": ragged_factory_ops.constant_value([[1, 2], [3]]),
627              "r2": ragged_factory_ops.constant_value([[1, 2, 3], [4]])
628          },
629          shape=[2, None],
630          validate=True,
631          err=ValueError,
632          msg=r"Error in shape of r2",
633      ),
634      dict(
635          fields={},
636          shape=(),
637          nrows=5,
638          err=ValueError,
639          msg="nrows must be None if shape.rank==0"),
640      dict(
641          fields={},
642          shape=(),
643          row_partitions=[0],
644          err=ValueError,
645          msg=r"row_partitions must be None or \[\] if shape.rank<2"),
646      dict(
647          fields={},
648          shape=(None, None, None),
649          row_partitions=[],
650          err=ValueError,
651          msg=r"len\(row_partitions\) must be shape.rank-1"),
652      dict(
653          fields={},
654          shape=[None],
655          err=ValueError,
656          msg="Must specify `nrows`, a fully specified `shape`, "
657          "or have `fields` if `rank=1`"),
658      dict(
659          fields={},
660          shape=[None, None],
661          err=ValueError,
662          msg="Must specify row_partitions, a fully specified shape, "
663          "or have fields if rank > 1"),
664      dict(
665          fields={},
666          shape=[None, None],
667          nrows=lambda: constant_op.constant(2, dtypes.int32),
668          row_partitions=lambda:
669          [row_partition.RowPartition.from_row_lengths([3, 4])],
670          err=ValueError,
671          msg="row_partition dtypes are inconsistent"),
672      dict(
673          fields=lambda: {
674              "a":
675                  ragged_factory_ops.constant([[1]],
676                                              row_splits_dtype=dtypes.int32),
677              "b":
678                  ragged_factory_ops.constant([[1]],
679                                              row_splits_dtype=dtypes.int64)
680          },
681          shape=[None, None],
682          err=ValueError,
683          msg="field values have incompatible row_partition dtypes"),
684  ])
685  def testFromFieldsErrors(self,
686                           fields,
687                           shape,
688                           nrows=None,
689                           row_partitions=None,
690                           validate=False,
691                           err=ValueError,
692                           msg=None,
693                           test_in_eager=True):
694    if not test_in_eager and context.executing_eagerly():
695      return
696    if callable(fields):
697      fields = fields()  # deferred construction.
698    if callable(nrows):
699      nrows = nrows()  # deferred construction.
700    if callable(row_partitions):
701      row_partitions = row_partitions()  # deferred construction.
702    with self.assertRaisesRegex(err, msg):
703      struct = StructuredTensor.from_fields(
704          fields=fields,
705          shape=shape,
706          nrows=nrows,
707          row_partitions=row_partitions,
708          validate=validate)
709      for field_name in struct.field_names():
710        self.evaluate(struct.field_value(field_name))
711      self.evaluate(struct.nrows())
712
713  def testMergeNrowsErrors(self):
714    nrows = constant_op.constant(5)
715    static_nrows = tensor_shape.Dimension(5)
716    value = constant_op.constant([1, 2, 3])
717    with self.assertRaisesRegex(ValueError, "fields have incompatible nrows"):
718      structured_tensor._merge_nrows(
719          nrows, static_nrows, value, dtypes.int32, validate=False)
720
721  def testNestedStructConstruction(self):
722    rt = ragged_factory_ops.constant([[1, 2], [3]])
723    struct1 = StructuredTensor.from_fields(shape=[], fields={"x": [1, 2]})
724    struct2 = StructuredTensor.from_fields(shape=[2], fields={"x": [1, 2]})
725    struct3 = StructuredTensor.from_fields(
726        shape=[], fields={
727            "r": rt,
728            "s": struct1
729        })
730    struct4 = StructuredTensor.from_fields(
731        shape=[2], fields={
732            "r": rt,
733            "s": struct2
734        })
735
736    self.assertEqual(struct3.shape.as_list(), [])
737    self.assertEqual(struct3.rank, 0)
738    self.assertEqual(set(struct3.field_names()), set(["r", "s"]))
739    self.assertAllEqual(struct3.field_value("r"), rt)
740    self.assertAllEqual(struct3.field_value("s"), struct1)
741
742    self.assertEqual(struct4.shape.as_list(), [2])
743    self.assertEqual(struct4.rank, 1)
744    self.assertEqual(set(struct4.field_names()), set(["r", "s"]))
745    self.assertAllEqual(struct4.field_value("r"), rt)
746    self.assertAllEqual(struct4.field_value("s"), struct2)
747
748  def testPartitionOuterDims(self):
749    a = dict(x=1, y=[1, 2])
750    b = dict(x=2, y=[3, 4])
751    c = dict(x=3, y=[5, 6])
752    d = dict(x=4, y=[7, 8])
753    st1 = StructuredTensor.from_pyval([a, b, c, d])
754
755    st2 = st1.partition_outer_dimension(
756        row_partition.RowPartition.from_row_splits([0, 2, 2, 3, 4]))
757    self.assertAllEqual(st2, [[a, b], [], [c], [d]])
758
759    st3 = st2.partition_outer_dimension(
760        row_partition.RowPartition.from_row_lengths([1, 0, 3, 0]))
761    self.assertAllEqual(st3, [[[a, b]], [], [[], [c], [d]], []])
762
763    # If we partition with uniform_row_lengths, then `x` is partitioned into
764    # a Tensor (not a RaggedTensor).
765    st4 = st1.partition_outer_dimension(
766        row_partition.RowPartition.from_uniform_row_length(
767            uniform_row_length=2, nvals=4, nrows=2))
768    self.assertAllEqual(
769        st4,
770        structured_tensor.StructuredTensor.from_pyval(
771            [[a, b], [c, d]],
772            structured_tensor.StructuredTensor.Spec(
773                _ragged_shape=DynamicRaggedShape.Spec(
774                    row_partitions=[],
775                    static_inner_shape=[2, 2],
776                    dtype=dtypes.int64),
777                _fields={
778                    "x":
779                        tensor_spec.TensorSpec([2, 2], dtypes.int32),
780                    "y":
781                        ragged_tensor.RaggedTensorSpec([2, 2, None],
782                                                       dtypes.int32)
783                })))
784
785  def testPartitionOuterDimension3(self):
786    rt = ragged_tensor.RaggedTensor.from_value_rowids(
787        array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
788    struct = structured_tensor.StructuredTensor.from_fields({"r": rt}, [2])
789    struct_2 = struct.partition_outer_dimension(
790        row_partition.RowPartition.from_row_splits([0, 1, 2]))
791    struct_3 = struct_2.partition_outer_dimension(
792        row_partition.RowPartition.from_row_splits([0, 1, 2]))
793    self.assertEqual(3, struct_3.rank)
794
795  def testWithPrivateSpecialType(self):
796    rt = ragged_tensor.RaggedTensor.from_value_rowids(
797        array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
798    pst = _PrivateSpecialType(rt)
799    pst_shape = array_ops.shape_v2(pst)
800    st = structured_tensor.StructuredTensor.from_fields_and_rank({"r": pst}, 1)
801    st_shape = st._ragged_shape
802    self.assertEqual(1, st.rank)
803    self.assertAllEqual(pst_shape[0], st_shape[0])
804
805  def testWithPrivateBrokenType(self):
806    rt = ragged_tensor.RaggedTensor.from_value_rowids(
807        array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
808    pbt = _PrivateBrokenType(rt)
809
810    with self.assertRaisesRegex(ValueError, "Error in shape of r"):
811      structured_tensor.StructuredTensor.from_fields_and_rank({"r": pbt}, 1)
812
813  def testPartitionOuterDimsErrors(self):
814    st = StructuredTensor.from_fields({})
815    partition = row_partition.RowPartition.from_row_splits([0])
816    with self.assertRaisesRegex(ValueError,
817                                r"Shape \(\) must have rank at least 1"):
818      st.partition_outer_dimension(partition)
819
820    with self.assertRaisesRegex(TypeError,
821                                "row_partition must be a RowPartition"):
822      st.partition_outer_dimension(10)
823
  # Each case gives a Python value (`pyval`), an optional `type_spec`, and a
  # zero-argument `expected` factory for the value that
  # `StructuredTensor.from_pyval(pyval, type_spec)` should produce.
  @parameterized.named_parameters([
      {
          "testcase_name": "ScalarEmpty",
          "pyval": {},
          "expected": lambda: StructuredTensor.from_fields(shape=[], fields={})
      },
      {
          "testcase_name": "ScalarSimple",
          "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]},
          "expected": lambda: StructuredTensor.from_fields(shape=[], fields={
              "a": 12,
              "b": [1, 2, 3],
              "c": ragged_factory_ops.constant([[1, 2], [3]])})
      },
      {
          "testcase_name": "ScalarSimpleWithTypeSpec",
          "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]},
          "type_spec": StructuredTensor.Spec._from_fields_and_rank(
              fields={
                  "a": tensor_spec.TensorSpec([], dtypes.int32),
                  "b": tensor_spec.TensorSpec([None], dtypes.int32),
                  "c": ragged_tensor.RaggedTensorSpec([None, None],
                                                      dtypes.int32)},
              rank=0),
          "expected": lambda: StructuredTensor.from_fields(shape=[], fields={
              "a": 12,
              "b": [1, 2, 3],
              "c": ragged_factory_ops.constant([[1, 2], [3]])})
      },
      {
          "testcase_name": "ScalarWithNestedStruct",
          "pyval": {"a": 12, "b": [1, 2, 3], "c": {"x": b"Z", "y": [10, 20]}},
          "expected": lambda: StructuredTensor.from_fields(shape=[], fields={
              "a": 12,
              "b": [1, 2, 3],
              "c": StructuredTensor.from_fields(shape=[], fields={
                  "x": "Z",
                  "y": [10, 20]})})
      },
      {
          "testcase_name": "EmptyList",
          "pyval": [],
          "expected": lambda: [],
      },
      {
          "testcase_name": "ListOfEmptyList",
          "pyval": [[], []],
          "expected": lambda: [[], []],
      },
      {
          "testcase_name": "EmptyListWithTypeSpecAndFields",
          "pyval": [],
          "type_spec": structured_tensor.StructuredTensor.Spec._from_fields_and_rank(
              fields={"a": tensor_spec.TensorSpec([0], dtypes.int32)},
              rank=1),
          "expected": lambda: StructuredTensor.from_fields(shape=[0], fields={
              "a": []})
      },
      {
          "testcase_name": "EmptyListWithTypeSpecNoFieldsShape0_5",
          "pyval": [],
          "type_spec": StructuredTensor.Spec._from_shape(DynamicRaggedShape.Spec(
              row_partitions=[],
              static_inner_shape=[0, 5],
              dtype=dtypes.int64)),
          "expected": lambda: StructuredTensor.from_fields(shape=[0, 5],
                                                           fields={})
      },
      {
          "testcase_name": "EmptyListWithTypeSpecNoFieldsShape1_0",
          "pyval": [[]],
          "type_spec": StructuredTensor.Spec._from_shape(
              DynamicRaggedShape.Spec(
                  row_partitions=[],
                  static_inner_shape=[1, 0],
                  dtype=dtypes.int64)),
          "expected": lambda: StructuredTensor.from_shape(
              DynamicRaggedShape.from_lengths([1, 0]))
      },
      {
          "testcase_name": "VectorOfDict",
          "pyval": [{"a": 1}, {"a": 2}],
          "expected": lambda: StructuredTensor.from_fields(shape=[2], fields={
              "a": [1, 2]})
      },
      {
          "testcase_name": "VectorOfDictWithNestedStructScalar",
          "pyval": [{"a": 1, "b": {"x": [1, 2]}},
                    {"a": 2, "b": {"x": [3]}}],
          "expected": lambda: StructuredTensor.from_fields(shape=[2], fields={
              "a": [1, 2],
              "b": StructuredTensor.from_fields(shape=[2], fields={
                  "x": ragged_factory_ops.constant([[1, 2], [3]])})}),
      },
      {
          "testcase_name": "VectorOfDictWithNestedStructVector",
          "pyval": [{"a": 1, "b": [{"x": [1, 2]}, {"x": [5]}]},
                    {"a": 2, "b": [{"x": [3]}]}],
          "expected": lambda: StructuredTensor.from_fields(shape=[2], fields={
              "a": [1, 2],
              "b": StructuredTensor.from_fields(shape=[2, None], fields={
                  "x": ragged_factory_ops.constant([[[1, 2], [5]], [[3]]])})}),
      },
      {
          "testcase_name": "Ragged2DOfDict",
          "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},],
                    [{"a": 4}, {"a": 5}]],
          "expected": lambda: StructuredTensor.from_fields(
              shape=[2, None],
              fields={
                  "a": ragged_factory_ops.constant([[1, 2, 3], [4, 5]])})
      },
      {
          # Without a type_spec, every field with rank > 1 is encoded as a
          # RaggedTensor, even when its rows happen to have equal lengths:
          "testcase_name": "MatrixOfDictWithoutTypeSpec",
          "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},],
                    [{"a": 4}, {"a": 5}, {"a": 6}]],
          "expected": lambda: StructuredTensor.from_fields(
              shape=[2, None], fields={
                  "a": ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])})
      },
      {
          # A type_spec can be used to specify a uniform StructuredTensor shape.
          "testcase_name": "MatrixOfDictWithTypeSpec",
          "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},],
                    [{"a": 4}, {"a": 5}, {"a": 6}]],
          "type_spec": structured_tensor.StructuredTensorSpec([2, 3], {
              "a": tensor_spec.TensorSpec(None, dtypes.int32)}),
          "expected": lambda: StructuredTensor.from_fields(
              shape=[2, 3], fields={"a": [[1, 2, 3], [4, 5, 6]]})
      },
  ])  # pyformat: disable
  def testPyvalConversion(self, pyval, expected, type_spec=None):
    """Checks StructuredTensor.from_pyval and the to_pyval round trip.

    Args:
      pyval: Python value to convert.
      expected: Zero-argument callable returning the expected result of the
        conversion.
      type_spec: Optional type spec passed through to `from_pyval`.
    """
    expected = expected()  # Deferred init because it creates tensors.
    actual = structured_tensor.StructuredTensor.from_pyval(pyval, type_spec)
    self.assertAllEqual(actual, expected)
    # The round trip back to Python is only checked for StructuredTensor
    # results, and only when running eagerly.
    if isinstance(actual, structured_tensor.StructuredTensor):
      if context.executing_eagerly():  # to_pyval only available in eager.
        self.assertEqual(actual.to_pyval(), pyval)
963
964  def testStructuredTensorSpecFactory(self):
965    spec = StructuredTensor.Spec._from_fields_and_rank(
966        fields={
967            "a": tensor_spec.TensorSpec([], dtypes.int32),
968            "b": tensor_spec.TensorSpec([None], dtypes.int32),
969            "c": ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)},
970        rank=0)
971    self.assertEqual(spec.rank, 0)
972
  # Each case builds a StructuredTensor lazily (`st` is a zero-argument
  # factory) and gives the Python value `to_pyval()` should return. The last
  # two cases show that empty innermost dimensions convert to empty lists.
  @parameterized.named_parameters([
      dict(
          testcase_name="NoFieldsRaggedRank0",
          st=lambda: StructuredTensor.from_fields({}, (3,)),
          expected=[{}, {}, {}]),
      dict(
          testcase_name="NoFieldsRaggedRank1",
          st=lambda: StructuredTensor.from_fields(
              {}, (2, None),
              row_partitions=[
                  row_partition.RowPartition.from_row_lengths([3, 2])]),
          expected=[[{}, {}, {}], [{}, {}]]),
      dict(
          testcase_name="NoFieldsRaggedRank2",
          st=lambda: StructuredTensor.from_fields(
              {}, (2, None, None),
              row_partitions=[
                  row_partition.RowPartition.from_row_lengths([2, 1]),
                  row_partition.RowPartition.from_row_lengths([2, 3, 1])]),
          expected=[[[{}, {}], [{}, {}, {}]], [[{}]]]),
      dict(
          testcase_name="NoFieldsRaggedRank2NoDicts",
          st=lambda: StructuredTensor.from_fields(
              {}, (1, None, None),
              row_partitions=[
                  row_partition.RowPartition.from_row_lengths([2]),
                  row_partition.RowPartition.from_row_lengths([0, 0])]),
          expected=[[[], []]]),
      dict(
          testcase_name="NestedStructTensorWithNoFields",
          st=lambda: StructuredTensor.from_fields(
              {
                  "foo": ragged_factory_ops.constant([[[], []]]),
                  "bar": StructuredTensor.from_fields(
                      {}, (1, None, None, None), row_partitions=[
                          row_partition.RowPartition.from_row_lengths([2]),
                          row_partition.RowPartition.from_row_lengths([0, 0]),
                          row_partition.RowPartition.from_row_lengths([]),
                      ])

              }, (1, None, None),),
          expected=[[[], []]]),
  ])  # pyformat: disable
  def testToPyval(self, st, expected):
    """Checks StructuredTensor.to_pyval; the check is skipped when not eager.

    Args:
      st: Zero-argument callable that builds the StructuredTensor under test.
      expected: The Python value `to_pyval()` should return.
    """
    if context.executing_eagerly():  # to_pyval only available in eager.
      st = st()  # Deferred init because it creates tensors.
      self.assertEqual(st.to_pyval(), expected)
1020
1021  @parameterized.named_parameters([
1022      dict(testcase_name="MissingKeys",
1023           pyval=[{"a": [1, 2]}, {"b": [3, 4]}],
1024           err=KeyError,
1025           msg="'b'"),
1026      dict(testcase_name="TypeSpecMismatch_DictKey",
1027           pyval={"a": 1},
1028           type_spec=StructuredTensor.Spec._from_fields_and_rank(
1029               fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)},
1030               rank=1),
1031           msg=r"Value at \(\) does not match typespec"),
1032      dict(testcase_name="TypeSpecMismatch_ListDictKey",
1033           pyval=[{"a": 1}],
1034           type_spec=StructuredTensor.Spec._from_fields_and_rank(
1035               fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)},
1036               rank=1),
1037           msg=r"Value at \(\) does not match typespec"),
1038      dict(testcase_name="TypeSpecMismatch_RankMismatch",
1039           pyval=[{"a": 1}],
1040           type_spec=StructuredTensor.Spec._from_fields_and_rank(
1041               fields={"a": tensor_spec.TensorSpec([], dtypes.int32)},
1042               rank=0),
1043           msg=r"Value at \(\) does not match typespec \(rank mismatch\)"),
1044      dict(testcase_name="TypeSpecMismatch_Scalar",
1045           pyval=0,
1046           type_spec=StructuredTensor.Spec._from_shape(
1047               DynamicRaggedShape.Spec(
1048                   row_partitions=[],
1049                   static_inner_shape=[],
1050                   dtype=dtypes.int64)),
1051           msg=r"Value at \(\) does not match typespec"),
1052      dict(testcase_name="TypeSpecMismatch_ListTensor",
1053           pyval={"a": [[1]]},
1054           type_spec=StructuredTensor.Spec._from_fields_and_rank(
1055               fields={"a": tensor_spec.TensorSpec([], dtypes.int32)},
1056               rank=0),
1057           msg=r"Value at \('a',\) does not match typespec"),
1058      dict(testcase_name="TypeSpecMismatch_ListTensorDeep",
1059           pyval={"a": {"b": [[1]]}},
1060           type_spec=StructuredTensor.Spec._from_fields_and_rank(
1061               fields={"a": StructuredTensor.Spec._from_fields_and_rank(
1062                   fields={"b": tensor_spec.TensorSpec([], dtypes.int32)},
1063                   rank=0
1064               )},
1065               rank=0),
1066           msg=r"Value at \('a', 'b'\) does not match typespec"),
1067      dict(testcase_name="TypeSpecMismatch_ListTensorDeep_infer",
1068           pyval={"a": [{"b": [[1]]}, {"b": [["c"]]}]},
1069           type_spec=None,
1070           msg=r"Error parsing path \('a', 'b'\)"),
1071      dict(testcase_name="TypeSpecMismatch_ListTensorDeep_infer2",
1072           pyval=[{"a": 1}, {"a": "c"}],
1073           type_spec=None,
1074           msg=r"Error parsing path \('a',\)"),
1075      dict(testcase_name="TypeSpecMismatch_ListSparse",
1076           pyval=[1, 2],
1077           type_spec=sparse_tensor.SparseTensorSpec([None], dtypes.int32),
1078           msg=r"Value at \(\) does not match typespec"),
1079      dict(testcase_name="TypeSpecMismatch_ListStruct",
1080           pyval=[[1]],
1081           type_spec=StructuredTensor.Spec._from_fields_and_rank(
1082               fields={"a": tensor_spec.TensorSpec([1, 1], dtypes.int32)},
1083               rank=2),
1084           msg=r"Value at \(\) does not match typespec"),
1085      dict(testcase_name="InconsistentDictionaryDepth",
1086           pyval=[{}, [{}]],
1087           msg="Inconsistent depth of dictionaries"),
1088      dict(testcase_name="FOO",
1089           pyval=[[{}], 5],
1090           msg="Expected dict or nested list/tuple of dict"),
1091
1092  ])  # pyformat: disable
1093  def testFromPyvalError(self, pyval, err=ValueError, type_spec=None, msg=None):
1094    with self.assertRaisesRegex(err, msg):
1095      structured_tensor.StructuredTensor.from_pyval(pyval, type_spec)
1096
1097  def testToPyvalRequiresEagerMode(self):
1098    st = structured_tensor.StructuredTensor.from_pyval({"a": 5})
1099    if not context.executing_eagerly():
1100      with self.assertRaisesRegex(ValueError, "only supported in eager mode."):
1101        st.to_pyval()
1102
1103  @parameterized.named_parameters([
1104      (
1105          "Rank0",
1106          [],
1107      ),
1108      (
1109          "Rank1",
1110          [5, 3],
1111      ),
1112      (
1113          "Rank2",
1114          [5, 8, 3],
1115      ),
1116      (
1117          "Rank5",
1118          [1, 2, 3, 4, 5],
1119      ),
1120  ])
1121  def testRowPartitionsFromUniformShape(self, shape):
1122    for rank in range(len(shape)):
1123      partitions = structured_tensor._row_partitions_for_uniform_shape(
1124          ops.convert_to_tensor(shape), rank)
1125      self.assertLen(partitions, max(0, rank - 1))
1126      if partitions:
1127        self.assertAllEqual(shape[0], partitions[0].nrows())
1128      for (dim, partition) in enumerate(partitions):
1129        self.assertAllEqual(shape[dim + 1], partition.uniform_row_length())
1130
  # Each case builds a RaggedTensor from `rt` (with `rt_ragged_rank`), asks
  # for `rank - 1` row partitions, and lists each partition's row_lengths.
  # Names read Shape_<dims>_Rank<rank>.
  @parameterized.named_parameters([
      # For shapes: U = uniform dimension; R = ragged dimension.
      dict(
          testcase_name="Shape_UR_Rank2",
          rt=[[1, 2], [], [3]],
          rt_ragged_rank=1,
          rank=2,
          expected_row_lengths=[[2, 0, 1]]),
      dict(
          testcase_name="Shape_URR_Rank2",
          rt=[[[1, 2], []], [[3]]],
          rt_ragged_rank=2,
          rank=2,
          expected_row_lengths=[[2, 1]]),
      dict(
          testcase_name="Shape_URU_Rank2",
          rt=[[[1], [2]], [[3]]],
          rt_ragged_rank=1,
          rank=2,
          expected_row_lengths=[[2, 1]]),
      dict(
          testcase_name="Shape_URR_Rank3",
          rt=[[[1, 2], []], [[3]]],
          rt_ragged_rank=2,
          rank=3,
          expected_row_lengths=[[2, 1], [2, 0, 1]]),
      dict(
          testcase_name="Shape_URU_Rank3",
          rt=[[[1], [2]], [[3]]],
          rt_ragged_rank=1,
          rank=3,
          expected_row_lengths=[[2, 1], [1, 1, 1]]),
      dict(
          testcase_name="Shape_URRUU_Rank2",
          rt=[[[[[1, 2]]]]],
          rt_ragged_rank=2,
          rank=2,
          expected_row_lengths=[[1]]),
      dict(
          testcase_name="Shape_URRUU_Rank3",
          rt=[[[[[1, 2]]]]],
          rt_ragged_rank=2,
          rank=3,
          expected_row_lengths=[[1], [1]]),
      dict(
          testcase_name="Shape_URRUU_Rank4",
          rt=[[[[[1, 2]]]]],
          rt_ragged_rank=2,
          rank=4,
          expected_row_lengths=[[1], [1], [1]]),
      dict(
          testcase_name="Shape_URRUU_Rank5",
          rt=[[[[[1, 2]]]]],
          rt_ragged_rank=2,
          rank=5,
          expected_row_lengths=[[1], [1], [1], [2]]),
  ])
  def testRowPartitionsForRaggedTensor(self, rt, rt_ragged_rank, rank,
                                       expected_row_lengths):
    """Checks _row_partitions_for_ragged_tensor against expected row lengths.

    Args:
      rt: Nested Python list converted with `ragged_factory_ops.constant`.
      rt_ragged_rank: ragged_rank used for that conversion.
      rank: Rank passed to `_row_partitions_for_ragged_tensor`; `rank - 1`
        partitions are expected.
      expected_row_lengths: Expected `row_lengths()` of each partition, in
        order.
    """
    rt = ragged_factory_ops.constant(rt, rt_ragged_rank)
    partitions = structured_tensor._row_partitions_for_ragged_tensor(
        rt, rank, dtypes.int64)
    self.assertLen(partitions, rank - 1)
    self.assertLen(partitions, len(expected_row_lengths))
    for partition, expected in zip(partitions, expected_row_lengths):
      self.assertAllEqual(partition.row_lengths(), expected)
1197
  # Each case converts `st` with from_pyval and merges axes outer_axis through
  # inner_axis (inclusive). Names read <rank>D_<outer_axis>_<inner_axis>.
  @parameterized.named_parameters([
      dict(
          testcase_name="2D_0_1",
          st=[[{"x": 1}, {"x": 2}], [{"x": 3}]],
          outer_axis=0, inner_axis=1,
          expected=[{"x": 1}, {"x": 2}, {"x": 3}]),
      dict(
          testcase_name="3D_0_1",
          st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
          outer_axis=0, inner_axis=1,
          expected=[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}]]),
      dict(
          testcase_name="3D_1_2",
          st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
          outer_axis=1, inner_axis=2,
          expected=[[{"x": 1}, {"x": 2}, {"x": 3}], [{"x": 4}]]),
      dict(
          testcase_name="3D_0_2",
          st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
          outer_axis=0, inner_axis=2,
          expected=[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]),
      dict(
          testcase_name="4D_0_1",
          st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
              [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
          outer_axis=0, inner_axis=1,
          expected=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]],
                    [[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]),
      dict(
          testcase_name="4D_0_2",
          st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
              [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
          outer_axis=0, inner_axis=2,
          expected=[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}],
                    [{"x": 5}], [{"x": 6}], [{"x": 7}]]),
      dict(
          testcase_name="4D_0_3",
          st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
              [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
          outer_axis=0, inner_axis=3,
          expected=[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4},
                    {"x": 5}, {"x": 6}, {"x": 7}]),
      dict(
          testcase_name="4D_1_2",
          st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
              [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
          outer_axis=1, inner_axis=2,
          expected=[[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}]],
                    [[{"x": 5}], [{"x": 6}], [{"x": 7}]]]),
      dict(
          testcase_name="4D_1_3",
          st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
              [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
          outer_axis=1, inner_axis=3,
          expected=[[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}],
                    [{"x": 5}, {"x": 6}, {"x": 7}]]),
      dict(
          testcase_name="4D_2_3",
          st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]],
              [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]],
          outer_axis=2, inner_axis=3,
          expected=[[[{"x": 1}, {"x": 2}, {"x": 3}], [{"x": 4}]],
                    [[{"x": 5}], [{"x": 6}, {"x": 7}]]]),
  ])  # pyformat: disable
  def testMergeDims(self, st, outer_axis, inner_axis, expected):
    """Checks StructuredTensor.merge_dims on from_pyval-built inputs.

    Args:
      st: Python value for the input StructuredTensor.
      outer_axis: First axis to merge.
      inner_axis: Last axis to merge (inclusive).
      expected: Python value that the merged result should equal.
    """
    st = StructuredTensor.from_pyval(st)
    result = st.merge_dims(outer_axis, inner_axis)
    self.assertAllEqual(result, expected)
1266
1267  def testMergeDimsDetail_3D_0_1(self):
1268    st = StructuredTensor.from_pyval(
1269        [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]])
1270    result = st.merge_dims(0, 1)
1271    expected_shape = tensor_shape.TensorShape([3, None])
1272    self.assertTrue(expected_shape.is_compatible_with(result.shape))
1273
1274  def testMergeDims_0_1(self):
1275    rt = ragged_tensor.RaggedTensor.from_value_rowids(
1276        array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1])
1277    struct = StructuredTensor.from_fields({"r": rt}, [2])
1278    struct_2 = struct.partition_outer_dimension(
1279        row_partition.RowPartition.from_row_splits([0, 1, 2]))
1280    struct_3 = struct_2.partition_outer_dimension(
1281        row_partition.RowPartition.from_row_splits([0, 1, 2]))
1282    self.assertLen(struct_3.row_partitions, 2)
1283    merged = struct_3.merge_dims(0, 1)
1284    self.assertLen(merged.row_partitions, 1)
1285
1286  def testMergeDimsError(self):
1287    st = StructuredTensor.from_pyval([[[{"a": 5}]]])
1288    with self.assertRaisesRegex(
1289        ValueError, r"Expected outer_axis \(2\) to be less than "
1290        r"or equal to inner_axis \(1\)"):
1291      st.merge_dims(2, 1)
1292
1293  def testTupleFieldValue(self):
1294    st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
1295    self.assertAllEqual(st.field_value(("a",)), 5)
1296    self.assertAllEqual(st.field_value(("b", "c")), [1, 2, 3])
1297    expected = r"Field path \(.*a.*,.*b.*\) not found in .*"
1298    with self.assertRaisesRegex(KeyError, expected):
1299      st.field_value(("a", "b"))
1300
  # Each case promotes the field at `source_path` into a new field named
  # `new_field_name`. Per the expected values, the new field is added to the
  # struct at `source_path[:-2]` (the source field's grandparent), and any
  # repeated dimensions in between are flattened (see "docs_tokens").
  @parameterized.named_parameters([
      dict(
          testcase_name="scalar_scalar_scalar",
          st={"b": {"a": 5}},
          source_path=("b", "a"),
          new_field_name="new_field",
          expected={"b": {"a": 5}, "new_field": 5},),
      dict(
          testcase_name="scalar_scalar_repeated",
          st={"b": {"a": [5, 3]}},
          source_path=("b", "a"),
          new_field_name="new_field",
          expected={"b": {"a": [5, 3]}, "new_field": [5, 3]}),
      dict(
          testcase_name="scalar_scalar_repeated2",
          st={"b": {"a": [[7], [5, 3]]}},
          source_path=("b", "a"),
          new_field_name="new_field",
          expected={"b": {"a": [[7], [5, 3]]}, "new_field": [[7], [5, 3]]}),
      dict(
          testcase_name="repeated_scalar_repeated",
          st=[{"b": {"a": [7]}},
              {"b": {"a": [5, 3]}}],
          source_path=("b", "a"),
          new_field_name="new_field",
          expected=[{"b": {"a": [7]}, "new_field": [7]},
                    {"b": {"a": [5, 3]}, "new_field": [5, 3]}]),
      dict(
          testcase_name="repeated_scalar_repeated2",
          st=[{"b": {"a": [[5, 7], []]}},
              {"b": {"a": [[5, 1], [3]]}}],
          source_path=("b", "a"),
          new_field_name="new_field",
          expected=[{"b": {"a": [[5, 7], []]},
                     "new_field": [[5, 7], []]},
                    {"b": {"a": [[5, 1], [3]]},
                     "new_field": [[5, 1], [3]]}]),
      dict(
          testcase_name="scalar_scalar_scalar_scalar",
          st={"a": {"b": {"c": 7}}},
          source_path=("a", "b", "c"),
          new_field_name="new_field",
          expected={"a": {"b": {"c": 7}, "new_field": 7}}),
      dict(
          testcase_name="repeated_scalar_scalar_scalar",
          st=[{"a": {"b": {"c": 7}}},
              {"a": {"b": {"c": 5}}}],
          source_path=("a", "b", "c"),
          new_field_name="new_field",
          expected=[{"a": {"b": {"c": 7}, "new_field": 7}},
                    {"a": {"b": {"c": 5}, "new_field": 5}}],),
      dict(
          testcase_name="repeated_repeated_scalar_scalar",
          st=[{"a": [{"b": {"c": 7}}, {"b": {"c": 3}}]},
              {"a": [{"b": {"c": 5}}]}],
          source_path=("a", "b", "c"),
          new_field_name="new_field",
          expected=[{"a": [{"b": {"c": 7}, "new_field": 7},
                           {"b": {"c": 3}, "new_field": 3}]},
                    {"a": [{"b": {"c": 5}, "new_field": 5}]}]),
      dict(
          testcase_name="docs_tokens",
          st=[{"docs": [{"tokens": [7, 17]}, {"tokens": [3, 13]}]},
              {"docs": [{"tokens": [5, 15]}]}],
          source_path=("docs", "tokens"),
          new_field_name="docs_tokens",
          expected=[{"docs": [{"tokens": [7, 17]}, {"tokens": [3, 13]}],
                     "docs_tokens": [7, 17, 3, 13]},
                    {"docs": [{"tokens": [5, 15]}],
                     "docs_tokens": [5, 15]}],
          ),
      dict(
          testcase_name="repeated_repeated_scalar_repeated",
          st=[{"a": [{"b": {"c": [7, 17]}}, {"b": {"c": [3, 13]}}]},
              {"a": [{"b": {"c": [5, 15]}}]}],
          source_path=("a", "b", "c"),
          new_field_name="new_field",
          expected=[{"a": [{"b": {"c": [7, 17]}, "new_field": [7, 17]},
                           {"b": {"c": [3, 13]}, "new_field": [3, 13]}]},
                    {"a": [{"b": {"c": [5, 15]}, "new_field": [5, 15]}]}]),
      dict(
          testcase_name="scalar_scalar_scalar_repeated",
          st={"a": {"b": {"c": [7, 3, 5]}}},
          source_path=("a", "b", "c"),
          new_field_name="new_field",
          expected={"a": {"b": {"c": [7, 3, 5]}, "new_field": [7, 3, 5]}}),
      dict(
          testcase_name="repeated_repeated_scalar_repeated2",
          st=[{"a": [{"b": {"c": [[7, 3], [17]]}}, {"b": {"c": [[3, 13]]}}]},
              {"a": [{"b": {"c": [[5, 15]]}}]}],
          source_path=("a", "b", "c"),
          new_field_name="new_field",
          expected=[{"a": [{"b": {"c": [[7, 3], [17]]},
                            "new_field": [[7, 3], [17]]},
                           {"b": {"c": [[3, 13]]},
                            "new_field": [[3, 13]]}]},
                    {"a": [{"b": {"c": [[5, 15]]},
                            "new_field": [[5, 15]]}]}]),
      dict(testcase_name="example_4_promote_of_labeled_vector",
           st=[{"user_info": [{"gaia_id": {"vec": [0, 1, 2]}}]},
               {"user_info": [{"gaia_id": {"vec": [3, 4, 5]}}]}],
           source_path=("user_info", "gaia_id"),
           new_field_name="user_info_gaia_id",
           expected=[{"user_info": [{"gaia_id": {"vec": [0, 1, 2]}}],
                      "user_info_gaia_id": [{"vec": [0, 1, 2]}]},
                     {"user_info": [{"gaia_id": {"vec": [3, 4, 5]}}],
                      "user_info_gaia_id": [{"vec": [3, 4, 5]}]}]),
      dict(
          testcase_name="promote_structure",
          st=[{"a": [{"aa": [{"b": {"c": 1}}, {"b": {"c": 8}}]}],},
              {"a": [{"aa": [{"b": {"c": 12}}]}],}],
          source_path=("a", "aa", "b"),
          new_field_name="new_field",
          expected=[{"a": [{"aa": [{"b": {"c": 1}}, {"b": {"c": 8}}],
                            "new_field": [{"c": 1}, {"c": 8}]}]},
                    {"a": [{"aa": [{"b": {"c": 12}}],
                            "new_field": [{"c": 12}]}]}])])  # pyformat: disable
  def testPromote(self, st, source_path, new_field_name, expected):
    """Checks StructuredTensor.promote on from_pyval-built inputs.

    Args:
      st: Python value for the input StructuredTensor.
      source_path: Path of the field to promote.
      new_field_name: Name of the field to create.
      expected: Python value for the expected result.
    """
    st2 = StructuredTensor.from_pyval(st)
    expected2 = StructuredTensor.from_pyval(expected)
    result = st2.promote(source_path, new_field_name)
    self.assertAllEqual(result, expected2)
1423
1424  def testPromoteDense(self):
1425    st = StructuredTensor.from_fields(
1426        {
1427            "a":
1428                StructuredTensor.from_fields(
1429                    {"b": [[[1, 11], [2, 12]], [[3, 13], [4, 14]]]},
1430                    shape=[2, 2, 2])
1431        },
1432        shape=[2])
1433    result = st.promote(("a", "b"), "new_field")
1434    self.assertEqual(st.rank, 1)
1435    self.assertEqual(st.field_value("a").rank, 3)
1436    self.assertAllEqual(
1437        result.field_value("new_field"), [[1, 11, 2, 12], [3, 13, 4, 14]])
1438
1439  def testMergeDimsGeneric(self):
1440    """This is an example of a dense tensor being merged, when outer=rank.
1441
1442    Note that outer=rank is equivalent to outer=rank - 1. And yet, from the
1443    perspective of promote, it is nice to be able to have this functionality
1444    directly available, because sometimes the rank of the parent equals the
1445    rank of the child.
1446
1447    Finally, note that merge_dims for Ragged and StructuredTensor would not
1448    accept this as a valid argument.
1449
1450    Note: _merge_dims_generic is private, but these unit tests help to
1451    discuss the proper API definition.
1452    """
1453    t = array_ops.constant([[[1, 11], [2, 12]], [[3, 13], [4, 14]]])
1454    t2 = structured_tensor._merge_dims_generic(t, 1, 3)
1455    self.assertAllEqual(t2, [[1, 11, 2, 12], [3, 13, 4, 14]])
1456
1457  def testMergeDimsGenericNoop(self):
1458    """This is an example of a dense tensor being merged, when outer=inner.
1459
1460    Sometimes, when promoting, the parent and grandparent ranks are equal.
1461    Finally, note that merge_dims for Ragged and StructuredTensor would not
1462    accept this as a valid argument. This should be aligned.
1463    """
1464    t = array_ops.constant([[[1, 11], [2, 12]], [[3, 13], [4, 14]]])
1465    t2 = structured_tensor._merge_dims_generic(t, 2, 2)
1466    self.assertAllEqual(t2, [[[1, 11], [2, 12]], [[3, 13], [4, 14]]])
1467
  def testRepr(self):
    """Checks repr() of a StructuredTensor with a nested StructuredTensor field.

    The expected text differs by mode: eager reprs show tensor values, while
    graph reprs show tensor names.
    """
    st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}})
    if context.executing_eagerly():
      # `[1:]` drops the newline that follows the opening triple quote.
      expected = textwrap.dedent("""
          <StructuredTensor(
              fields={
                  "a": tf.Tensor(5, shape=(), dtype=int32),
                  "b": <StructuredTensor(
                          fields={
                              "c": tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
                          shape=())>},
              shape=())>""")[1:]
    else:
      # Graph-mode reprs include op names (e.g. "Const:0"), so this string is
      # sensitive to how from_pyval names the ops it creates.
      expected = textwrap.dedent("""
          <StructuredTensor(
              fields={
                  "a": Tensor("Const:0", shape=(), dtype=int32),
                  "b": <StructuredTensor(
                          fields={
                              "c": Tensor("RaggedConstant/Const:0", shape=(3,), dtype=int32)},
                          shape=())>},
              shape=())>""")[1:]
    self.assertEqual(repr(st), expected)
1491
1492  def testPartitionOuterDimension2DDenseField(self):
1493    struct = structured_tensor.StructuredTensor.from_fields(
1494        fields={"r": array_ops.constant([[1, 2], [3, 4]])}, shape=[2])
1495
1496    result = struct.partition_outer_dimension(
1497        row_partition.RowPartition.from_uniform_row_length(2, 2))
1498    r = result.field_value("r")
1499    self.assertAllEqual(r, [[[1, 2], [3, 4]]])
1500
1501  @parameterized.parameters([
1502      # Simple example.
1503      (
1504          {"a": 12, "b": 23},
1505          {"a": 7},
1506      ),
1507      # New field.
1508      (
1509          {"a": 12},
1510          {("b",): 13},
1511      ),
1512      # Nested example.
1513      (
1514          {"a": 12, "b": {"c": 23}},
1515          {("b", "c"): 7},
1516      ),
1517      # Multipe updates.
1518      (
1519          {"a": 12, "b": {"c": 23}},
1520          {"a": 3, ("b", "c"): 7},
1521      ),
1522      # Deep updates.
1523      (
1524          {"a": 12, "b": {"c": 23, "d": {"e": 11}}},
1525          {("b", "c"): 7, ("b", "d", "e"): 13},
1526      ),
1527      # Multiple updates to the same substructure.
1528      (
1529          {"a": 12, "b": {"c": 23, "d": {"e": 11}}},
1530          {("b", "c"): 7, ("b", "f"): 13},
1531      ),
1532      # Scalar to non-scalar elements. Shape remains unchanged.
1533      (
1534          {"a": 5},
1535          {"a": ragged_factory_ops.constant_value([[51, 52], [61, 62, 63]])},
1536      ),
1537      # Non-scalar element to scalar.
1538      (
1539          {"c": {"a": [5, 3], "b": 2}},
1540          {("c", "a"): 5},
1541      ),
1542      # Rank-1 StructuredTensor: shape is preserved and an item is added.
1543      (
1544          [{"a": 5}, {"a": 6}],
1545          {"a": [15, 16], "b": np.array([0.9, 1.1])},
1546      ),
1547      # Non-scalar ragged elements, within a rank-2 StructuredTensor: elements
1548      # rows (inner dimensions) are changed, but StructuredTensor shape
1549      # (outer dimensions) are preserved.
1550      (
1551          [[{"a": [5]}], [{"a": [3, 4]}, {"a": [8]}]],
1552          {"a": ragged_factory_ops.constant_value([[[50, 60]], [[30], []]])},
1553      ),
1554  ])  # pyformat: disable
1555  def testWithUpdatesValues(self, pyval, updates):
1556    st = StructuredTensor.from_pyval(pyval)
1557    updated_st = st.with_updates(updates, validate=False)
1558    for key, value in updates.items():
1559      got = updated_st.field_value(key)
1560      self.assertAllEqual(
1561          value, got,
1562          "Update failed: key={}, value={}, got={}".format(key, value, got))
1563
1564  def testWithUpdatesFunctions(self):
1565    pyval = {"a": 12, "b": {"c": 23, "d": {"e": 11}}}
1566    st = StructuredTensor.from_pyval(pyval)
1567    st_updated = st.with_updates(
1568        {
1569            "a": lambda x: x + 1,
1570            ("b", "d", "e"): lambda x: x + 7
1571        }, validate=True)
1572    # Updated values.
1573    self.assertAllEqual(st_updated.field_value("a"), 13)
1574    self.assertAllEqual(st_updated.field_value(("b", "d", "e")), 18)
1575    # Unchanged value.
1576    self.assertAllEqual(st_updated.field_value(("b", "c")), 23)
1577
1578  def test_from_pyval_list_of_empty(self):
1579    """See b/183245576."""
1580    st = structured_tensor.StructuredTensor.from_pyval([{}])
1581    self.assertAllEqual([1], st.shape.as_list())
1582
1583  def test_from_pyval_list_of_empty_three(self):
1584    """See b/183245576."""
1585    st = structured_tensor.StructuredTensor.from_pyval([{}, {}, {}])
1586    self.assertAllEqual([3], st.shape.as_list())
1587    self.assertEmpty(st.field_names())
1588
1589  def test_from_pyval_deep_list_of_empty(self):
1590    """See b/183245576."""
1591    st = structured_tensor.StructuredTensor.from_pyval([[{
1592        "a": {},
1593        "b": [3, 4]
1594    }, {
1595        "a": {},
1596        "b": [5]
1597    }], [{
1598        "a": {},
1599        "b": [7, 8, 9]
1600    }]])
1601    self.assertAllEqual(2, st.rank)
1602    self.assertEqual(2, st.shape[0])
1603    self.assertEmpty(st.field_value("a").field_names())
1604
1605  def testWithUpdatesChecks(self):
1606    pyval = {"a": 12, "b": {"c": 23, "d": {"e": 11}}}
1607    st = StructuredTensor.from_pyval(pyval)
1608
1609    # Try to set non-existant sub-structure.
1610    with self.assertRaisesRegex(
1611        ValueError, r"cannot create new sub-field.*\('b', 'x'\).*is not set"):
1612      st.with_updates({("b", "x", "e"): 5})
1613
1614    # Try to set with path to a non-sub-structure.
1615    with self.assertRaisesRegex(
1616        ValueError, r"cannot create new sub-field.*\('b', 'c'\).*is not a "
1617        r"`StructuredTensor`"):
1618      st.with_updates({("b", "c", "e"): 5})
1619
1620    # Try to apply function to non-existing value.
1621    with self.assertRaisesRegex(
1622        ValueError, r"cannot update.*\('b', 'd', 'x'\).*does not already "
1623        r"exist"):
1624      st.with_updates({("b", "d", "x"): lambda x: x + 1})
1625
1626    # Empty names not allowed.
1627    with self.assertRaisesRegex(ValueError, r"does not allow empty names"):
1628      st.with_updates({(): lambda x: x + 1})
1629    with self.assertRaisesRegex(ValueError, r"does not allow empty names"):
1630      st.with_updates({("b", ""): lambda x: x + 1})
1631
1632    # Parent and child nodes cannot be updated simultaneously.
1633    with self.assertRaisesRegex(
1634        ValueError, r"does not allow both parent and child nodes.*"
1635        r"parent=\('b'.*child=\('b', 'd'"):
1636      st.with_updates({("b", "d"): lambda x: x + 1, "a": 3, "b": 10})
1637
1638    # Invalid shape change.
1639    with self.assertRaisesRegex(
1640        ValueError,
1641        r"`StructuredTensor.with_updates` failed for field \('c',\)"):
1642      st_with_shape = StructuredTensor.from_pyval([[{
1643          "c": {
1644              "a": 5,
1645              "b": 2
1646          }
1647      }], [{
1648          "c": {
1649              "a": 3,
1650              "b": 1
1651          }
1652      }, {
1653          "c": {
1654              "a": 8,
1655              "b": 18
1656          }
1657      }]])
1658      st_with_shape.with_updates({("c", "a"): 3})
1659
1660  def testWithUpdatesDelete(self):
1661    pyval = {"a": 12, "b": {"c": 23, "d": {"e": 11}}}
1662    st = StructuredTensor.from_pyval(pyval)
1663    updated_st = st.with_updates({("b", "c"): None}, validate=True)
1664    self.assertNotIn("c", updated_st.field_value("b").field_names())
1665    with self.assertRaisesRegex(ValueError,
1666                                r"cannot delete.*\('b', 'x'\).*not present"):
1667      st.with_updates({("b", "x"): None}, validate=True)
1668    with self.assertRaisesRegex(ValueError,
1669                                r"cannot delete.*\'x'.*not present"):
1670      st.with_updates({"x": None}, validate=False)
1671
1672    # Test that nrows() and rowpartitions() is preserved after removal.
1673    pyval = [[{"a": 1}, {"a": 2}], [{"a": 3}]]
1674    st = StructuredTensor.from_pyval(pyval)
1675    self.assertLen(st.row_partitions, 1)
1676    self.assertAllEqual(st.nrows(), 2)
1677    self.assertAllEqual(st.row_partitions[0].row_lengths(), [2, 1])
1678    updated_st = st.with_updates({("a",): None}, validate=True)
1679    self.assertLen(updated_st.row_partitions, 1)
1680    self.assertAllEqual(updated_st.nrows(), 2)
1681    self.assertAllEqual(updated_st.row_partitions[0].row_lengths(), [2, 1])
1682
1683    # Test that it works also for rank-1 and rank-0 empty results.
1684    pyval = [{"a": 1}, {"a": 2}]
1685    st = StructuredTensor.from_pyval(pyval)
1686    self.assertEqual(st.rank, 1)
1687    updated_st = st.with_updates({("a",): None}, validate=True)
1688    self.assertEqual(updated_st.rank, 1)
1689
1690    # assertEqual won't work because nrows() returns a tensor, and
1691    # assertEqual doesn't do the magic to convert them to numbers in a
1692    # way that works in eager/non-eager mode.
1693    self.assertAllEqual(updated_st.nrows(), 2)
1694    pyval = {"a": [0, 1]}
1695    st = StructuredTensor.from_pyval(pyval)
1696    self.assertEqual(st.rank, 0)
1697    updated_st = st.with_updates({("a",): None}, validate=True)
1698    self.assertEqual(updated_st.rank, 0)
1699    self.assertFalse(updated_st.row_partitions)
1700    self.assertIsNone(updated_st.nrows())
1701
1702  def test_from_pyval_deep_row_partitions(self):
1703    """See b/179195750."""
1704    st = structured_tensor.StructuredTensor.from_pyval([{
1705        "foo": [{
1706            "bar": [{
1707                "baz": [b"FW"]
1708            }]
1709        }]
1710    }])
1711    st2 = st.field_value(("foo", "bar"))
1712    self.assertLen(st2.row_partitions, st2.rank - 1)
1713
1714  def test_from_fields_deep_row_partitions(self):
1715    """Test a field with its own row_partition. See b/179195750."""
1716    st = structured_tensor.StructuredTensor.from_pyval([[[{"baz": [b"FW"]}]]])
1717    self.assertLen(st.row_partitions, st.rank - 1)
1718    st2 = structured_tensor.StructuredTensor.from_fields(
1719        fields={"bar": st}, shape=(None, None), validate=False)
1720    st3 = st2.field_value("bar")
1721    self.assertLen(st3.row_partitions, st3.rank - 1)
1722
1723  def test_structured_tensor_spec_shape_property(self):
1724    spec = StructuredTensor.Spec._from_shape(DynamicRaggedShape.Spec(
1725        row_partitions=[],
1726        static_inner_shape=[1, 2],
1727        dtype=dtypes.int64))
1728    self.assertEqual(spec.shape.as_list(), [1, 2])
1729    spec = StructuredTensor.Spec._from_shape(DynamicRaggedShape.Spec(
1730        row_partitions=[],
1731        static_inner_shape=[None],
1732        dtype=dtypes.int64))
1733    self.assertEqual(spec.shape.as_list(), [None])
1734
1735  def test_dynamic_ragged_shape_init_vector(self):
1736    x = constant_op.constant([1, 2, 3, 4])
1737    y = constant_op.constant([[1, 2], [3, 4], [5, 6], [7, 8]])
1738    fields = {"x": x, "y": y}
1739    nrows = constant_op.constant(4)
1740    shape = tensor_shape.TensorShape((4,))
1741    row_partitions = ()
1742    rs = structured_tensor_dynamic._dynamic_ragged_shape_init(
1743        fields, shape, nrows, row_partitions)
1744    self.assertEqual(
1745        repr(rs._to_tensor_shape()), repr(tensor_shape.TensorShape((4,))))
1746
1747  def test_dynamic_ragged_shape_init_scalar(self):
1748    x = constant_op.constant([1, 2, 3, 4])
1749    y = constant_op.constant([[1, 2], [3, 4], [5, 6], [7, 8]])
1750    fields = {"x": x, "y": y}
1751    nrows = None
1752    shape = tensor_shape.TensorShape(())
1753    row_partitions = ()
1754
1755    rs = structured_tensor_dynamic._dynamic_ragged_shape_init(
1756        fields, shape, nrows, row_partitions)
1757    self.assertEqual(
1758        repr(rs._to_tensor_shape()), repr(tensor_shape.TensorShape(())))
1759
1760  def test_dynamic_ragged_shape_init_ragged(self):
1761    x = ragged_factory_ops.constant_value([[1, 2, 3], [4]])
1762    fields = {"x": x}
1763    nrows = constant_op.constant(2, dtype=dtypes.int64)
1764    shape = tensor_shape.TensorShape([2, None])
1765    row_partitions = tuple(x._nested_row_partitions)
1766    rs = structured_tensor_dynamic._dynamic_ragged_shape_init(
1767        fields, shape, nrows, row_partitions)
1768    self.assertEqual(
1769        repr(rs._to_tensor_shape()), repr(tensor_shape.TensorShape((2, None))))
1770
1771
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand control to the
  # googletest test runner.
  googletest.main()
1774