xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/ragged/dynamic_ragged_shape_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.ragged.dynamic_ragged_shape."""
16
17from typing import Sequence, Union
18
19from absl.testing import parameterized
20import numpy as np
21
22from tensorflow.python.client import session
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.eager import context
25from tensorflow.python.eager import def_function
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors_impl
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.framework import test_util
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import gen_math_ops
36from tensorflow.python.ops import gradients_impl
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import string_ops
39from tensorflow.python.ops.ragged import dynamic_ragged_shape
40from tensorflow.python.ops.ragged import ragged_array_ops
41from tensorflow.python.ops.ragged import ragged_factory_ops
42from tensorflow.python.ops.ragged import ragged_tensor
43from tensorflow.python.ops.ragged.dynamic_ragged_shape import _LayerBroadcaster
44from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape
45from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
46from tensorflow.python.ops.ragged.row_partition import RowPartition
47from tensorflow.python.ops.ragged.row_partition import RowPartitionSpec
48from tensorflow.python.platform import googletest
49
50
def _to_row_partitions_from_lengths(
    lengths: Sequence[Union[int, Sequence[int]]]) -> Sequence[RowPartition]:
  """Converts a mixed list of dimension lengths into row partitions.

  Lets both ragged and uniform shapes be written compactly. For example,
  [2, [2,1], 2] represents a shape like:
  [[[0, 0], [0, 0]], [[0, 0]]]

  Args:
    lengths: a list of integers and lists of integers.

  Returns:
    a sequence of RowPartitions: the first element of the pair returned by
    dynamic_ragged_shape._to_row_partitions_and_nvals_from_lengths. The second
    element of that pair is discarded.
  """
  partitions_and_nvals = (
      dynamic_ragged_shape._to_row_partitions_and_nvals_from_lengths(lengths))
  row_partitions, _ = partitions_and_nvals
  return row_partitions
67
68
def _to_ragged_tensor_from_lengths(
    values, lengths: Sequence[Union[int, Sequence[int]]]) -> RaggedTensor:
  """Builds a ragged tensor (or a plain tensor) from values and lengths.

  Args:
    values: data handed to constant_op.constant to form the flat values.
    lengths: a list of integers and lists of integers, turned into row
      partitions by _to_row_partitions_from_lengths.

  Returns:
    The constant tensor itself when `lengths` produces no row partitions;
    otherwise a RaggedTensor made from that constant and the row partitions.
  """
  partitions = _to_row_partitions_from_lengths(lengths)
  flat_values = constant_op.constant(values)
  if partitions:
    return RaggedTensor._from_nested_row_partitions(flat_values, partitions)
  # No partitions: the constant is returned without any ragged wrapper.
  return flat_values
77
78
79def _divides(a, b):
80  return b % a == 0
81
82
def _next_prime(primes_so_far):
  """Returns the first integer past the last entry that no entry divides.

  Args:
    primes_so_far: list of integers. _lowest_primes passes the primes found
      so far, in increasing order; under that usage the result is the next
      prime.

  Returns:
    2 when `primes_so_far` is empty; otherwise the smallest integer greater
    than `primes_so_far[-1]` that is not a multiple of any entry.
  """
  candidate = primes_so_far[-1] + 1 if primes_so_far else 2
  # Keep advancing while some known prime divides the candidate.
  while any(_divides(prime, candidate) for prime in primes_so_far):
    candidate += 1
  return candidate
91
92
def _lowest_primes(n):
  """Returns a list holding the lowest `n` primes.

  Each new prime is obtained from _next_prime, which is given the list built
  so far.
  """
  primes = []
  for _ in range(n):
    primes += [_next_prime(primes)]
  return primes
99
100
101def _num_elements_of_lengths_with_rows(rows,
102                                       lengths: Sequence[Union[int,
103                                                               Sequence[int]]]):
104  """Helper function for _num_elements_of_lengths."""
105  if not lengths:
106    return rows
107  next_length = lengths[0]
108  if isinstance(next_length, int):
109    return _num_elements_of_lengths_with_rows(next_length * rows, lengths[1:])
110  else:
111    return _num_elements_of_lengths_with_rows(sum(next_length), lengths[1:])
112
113
def _num_elements_of_lengths(lengths: Sequence[Union[int, Sequence[int]]]):
  """Computes, in plain Python, the element count described by `lengths`.

  Meant as a static counterpart of
  DynamicRaggedShape.from_lengths(lengths)._num_elements().

  Args:
    lengths: a list of integers and lists of integers.

  Returns:
    The result of _num_elements_of_lengths_with_rows started from one row.
  """
  initial_rows = 1
  return _num_elements_of_lengths_with_rows(initial_rows, lengths)
117
118
def _to_prime_tensor_from_lengths(
    lengths: Sequence[Union[int, Sequence[int]]]) -> RaggedTensor:
  """Creates a tensor of the given shape whose values are the lowest primes.

  Args:
    lengths: a list of integers and lists of integers describing the shape.

  Returns:
    The result of reshaping the list of lowest primes (one per element of the
    shape) to DynamicRaggedShape.from_lengths(lengths).
  """
  target_shape = DynamicRaggedShape.from_lengths(lengths)
  primes = _lowest_primes(_num_elements_of_lengths(lengths))
  return ragged_array_ops.ragged_reshape(primes, target_shape)
125
126
127@test_util.run_all_in_graph_and_eager_modes
128class DynamicRaggedShapeTest(test_util.TensorFlowTestCase,
129                             parameterized.TestCase):
130
131  def assertRowPartitionEq(self,
132                           x: RowPartition,
133                           y: RowPartition,
134                           msg=None) -> None:
135    self.assertAllEqual(x.row_splits(), y.row_splits(), msg=msg)
136
137  def assertShapeEq(self,
138                    x: DynamicRaggedShape,
139                    y: DynamicRaggedShape,
140                    msg=None) -> None:
141    assert isinstance(x, DynamicRaggedShape)
142    assert isinstance(y, DynamicRaggedShape)
143    if msg is None:
144      msg = ''
145    self.assertLen(
146        x.row_partitions, len(y.row_partitions), msg=msg + ': length unequal')
147    for i in range(len(x.row_partitions)):
148      x_dims = x.row_partitions[i]
149      y_dims = y.row_partitions[i]
150      self.assertRowPartitionEq(
151          x_dims, y_dims, msg=msg + ': row_partition ' + str(i))
152    self.assertAllEqual(
153        x.inner_shape, y.inner_shape, msg=msg + ': shapes unequal')
154
155  def assertLayerBroadcasterEq(self, x: _LayerBroadcaster,
156                               y: _LayerBroadcaster) -> None:
157    assert isinstance(x, _LayerBroadcaster)
158    assert isinstance(y, _LayerBroadcaster)
159    self.assertAllEqual(x.gather_index, y.gather_index)
160
161  def assertBroadcasterEq(self, x: dynamic_ragged_shape._Broadcaster,
162                          y: dynamic_ragged_shape._Broadcaster) -> None:
163    assert isinstance(x, dynamic_ragged_shape._Broadcaster)
164    assert isinstance(y, dynamic_ragged_shape._Broadcaster)
165    self.assertShapeEq(x.source_shape, y.source_shape)
166    self.assertShapeEq(x.target_shape, y.target_shape)
167    self.assertLen(x._layer_broadcasters, len(y._layer_broadcasters))
168    for x_layer, y_layer in zip(x._layer_broadcasters, y._layer_broadcasters):
169      self.assertLayerBroadcasterEq(x_layer, y_layer)
170
171  @parameterized.parameters([
172      dict(value='x', row_partitions=[], inner_shape=()),
173      dict(value=['a', 'b', 'c'], row_partitions=[], inner_shape=[3]),
174      dict(
175          value=[['a', 'b', 'c'], ['d', 'e', 'f']],
176          row_partitions=(),
177          inner_shape=[2, 3]),
178      dict(
179          value=[[['a', 'b', 'c'], ['d', 'e', 'f']]],
180          row_partitions=(),
181          inner_shape=[1, 2, 3]),
182      dict(
183          value=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d', 'e']],
184                                                  ragged_rank=1),
185          row_partitions=[[0, 3, 5]],
186          inner_shape=[5]),
187      dict(
188          value=ragged_factory_ops.constant_value(
189              [[['a', 'b', 'c'], ['d', 'e', 'f']]], ragged_rank=1),
190          row_partitions=[[0, 2]],
191          inner_shape=[2, 3]),
192      dict(
193          value=ragged_factory_ops.constant_value(
194              [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1),
195          row_partitions=[[0, 2, 3]],
196          inner_shape=[3, 2, 1]),
197      dict(
198          value=ragged_factory_ops.constant_value([[10, 20], [30]]),
199          row_partitions=[[0, 2, 3]],
200          inner_shape=[3]),
201      # Docstring examples:
202      dict(value=[[1, 2, 3], [4, 5, 6]], row_partitions=[], inner_shape=[2, 3]),
203      dict(
204          value=ragged_factory_ops.constant_value([[1, 2], [], [3, 4, 5]]),
205          row_partitions=[[0, 2, 2, 5]],
206          inner_shape=[5]),
207      dict(
208          value=ragged_factory_ops.constant_value([[[1, 2], [3, 4]], [[5, 6]]],
209                                                  ragged_rank=1),
210          row_partitions=[[0, 2, 3]],
211          inner_shape=[3, 2]),
212      dict(
213          value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
214          row_partitions=[[0, 2, 3], [0, 2, 3, 5]],
215          inner_shape=[5]),
216  ])
217  def testFromTensor(self, value, row_partitions, inner_shape):
218    shape = DynamicRaggedShape.from_tensor(value)
219    row_partitions = [RowPartition.from_row_splits(x) for x in row_partitions]
220    expected = DynamicRaggedShape(row_partitions, inner_shape)
221    self.assertShapeEq(shape, expected)
222
223  # pylint:disable=g-long-lambda
224  @parameterized.parameters([
225      # from_lengths           | row_partitions            | inner_shape
226      # ---------------------- | --------------------------| -------------
227      # []                     | []                        | []
228      # [2, (3, 2)]            | [RP([3, 2])]              | [5]
229      # [2, 2]                 | []                        | [2, 2]
230      # [2, (3, 2), 7]         | [RP([3, 2])]              | [5, 7]
231      # [2, (2, 2), 3]         | [RP([2, 2])]              | [4, 3]
232      # [2, 2, 3]              | []                        | [2, 2, 3]
233      # [2, (2, 1), (2, 0, 3)] | [RP(2, 1), RP([2, 0, 3])] | [5]
234
235      dict(lengths=[], row_partitions=[], inner_shape=[]),
236      dict(
237          lengths=[2, (3, 2)],
238          row_partitions=lambda: [RowPartition.from_row_lengths([3, 2])],
239          inner_shape=[5]),
240      dict(lengths=[2, 2], row_partitions=[], inner_shape=[2, 2]),
241      dict(
242          lengths=[2, (3, 2), 7],
243          row_partitions=lambda: [RowPartition.from_row_lengths([3, 2])],
244          inner_shape=[5, 7]),
245      dict(
246          lengths=[2, (2, 2), 3],
247          row_partitions=lambda: [RowPartition.from_row_lengths([2, 2])],
248          inner_shape=[4, 3]),
249      dict(lengths=[2, 2, 3], row_partitions=[], inner_shape=[2, 2, 3]),
250      dict(
251          lengths=[2, (2, 1), (2, 0, 3)],
252          row_partitions=lambda: [
253              RowPartition.from_row_lengths([2, 1]),
254              RowPartition.from_row_lengths([2, 0, 3])
255          ],
256          inner_shape=[5]),
257      # from_lengths   | num_row    | row_partitions           | inner_shape
258      #                : partitions :                          :
259      # ---------------| -----------|--------------------------|------------
260      # [2, (3, 2), 2] | 2          | [RP([3, 2]), URP(2, 10)] | [10]
261      # [2, 2]         | 1          | [URP(2, 4)]              | [4]
262      # [2, 2, 3]      | 0          | []                       | [2, 2, 3]
263      # [2, 2, 3]      | 1          | [URP(2, 4)]              | [4, 3]
264      # [2, 2, 3]      | 2          | [URP(2, 4), URP(3, 12)]  | [12]
265      dict(lengths=[2, (3, 2), 2],
266           num_row_partitions=2,
267           row_partitions=lambda: [RowPartition.from_row_lengths([3, 2]),
268                                   RowPartition.from_uniform_row_length(2, 10)],
269           inner_shape=[10]),
270      dict(lengths=[2, 2],
271           num_row_partitions=1,
272           row_partitions=lambda: [RowPartition.from_uniform_row_length(2, 4)],
273           inner_shape=[4]),
274      dict(lengths=[2, 2, 3],
275           num_row_partitions=0,
276           row_partitions=[],
277           inner_shape=[2, 2, 3]),
278      dict(lengths=[2, 2, 3],
279           num_row_partitions=1,
280           row_partitions=lambda: [RowPartition.from_uniform_row_length(2, 4)],
281           inner_shape=[4, 3]),
282      dict(lengths=[2, 2, 3],
283           num_row_partitions=2,
284           row_partitions=lambda: [RowPartition.from_uniform_row_length(2, 4),
285                                   RowPartition.from_uniform_row_length(3, 12)],
286           inner_shape=[12])
287  ])
288  def testFromLengths(self,
289                      lengths,
290                      row_partitions,
291                      inner_shape,
292                      num_row_partitions=None):
293    if callable(row_partitions):
294      row_partitions = row_partitions()
295    shape = DynamicRaggedShape.from_lengths(
296        lengths, num_row_partitions=num_row_partitions)
297    expected = DynamicRaggedShape(row_partitions, inner_shape)
298    self.assertShapeEq(shape, expected)
299
300  @parameterized.parameters([
301      dict(
302          lengths=[2, (2, 1, 3)],
303          num_row_partitions=1,
304          msg='Shape not consistent'),
305      dict(
306          lengths=[2, 3],
307          num_row_partitions=2,
308          msg='num_row_partitions should be less than'),
309      dict(
310          lengths=[],
311          num_row_partitions=3,
312          msg='num_row_partitions==0 for a scalar shape'),
313      dict(
314          lengths=[(5, 3), 3],
315          num_row_partitions='a',
316          msg='num_row_partitions should be an int or None'),
317      dict(
318          lengths=[(5, 'a'), 3],
319          num_row_partitions=0,
320          msg='element of lengths should be int or tuple of ints'),
321      dict(
322          lengths=['a'],
323          num_row_partitions=0,
324          msg='element of lengths should be int or tuple of ints'),
325      dict(lengths=7, num_row_partitions=0, msg='lengths should be a list')
326  ])
327  def testFromLengthsError(self, lengths, msg, num_row_partitions=None):
328    with self.assertRaisesRegex(ValueError, msg):
329      DynamicRaggedShape.from_lengths(
330          lengths, num_row_partitions=num_row_partitions)
331
332  def testGetItemSliceRankUnknownA(self):
333    if not context.executing_eagerly():
334      original_t = array_ops.placeholder_with_default(np.array([4, 5, 3]), None)
335      sh = DynamicRaggedShape.from_tensor(original_t)
336      known = sh[:1]
337      self.assertIsNone(known.rank)
338
339  def testGetItemSliceRankUnknownLong(self):
340    if not context.executing_eagerly():
341      original_t = array_ops.placeholder_with_default(np.array([4, 5, 3]), None)
342      sh = DynamicRaggedShape.from_tensor(original_t)
343      unknown = sh[:20]
344      self.assertIsNone(unknown.rank)
345
346  def testGetItemSliceRankKnownLong(self):
347    if not context.executing_eagerly():
348      original_t = constant_op.constant([4, 5, 3], dtypes.float32)
349      sh = DynamicRaggedShape.from_tensor(original_t)
350      unknown = sh[:20]
351      self.assertEqual(unknown.rank, 1)
352
353  def testGetBroadcaster(self):
354    origin_shape = DynamicRaggedShape(
355        [RowPartition.from_uniform_row_length(1, 3)], inner_shape=[3])
356    dest_shape = DynamicRaggedShape(
357        [RowPartition.from_uniform_row_length(2, 6)], inner_shape=[6])
358    actual = dynamic_ragged_shape._get_broadcaster(origin_shape, dest_shape)
359    expected = dynamic_ragged_shape._Broadcaster(origin_shape, dest_shape, [
360        _LayerBroadcaster.from_gather_index([0, 1, 2]),
361        _LayerBroadcaster.from_gather_index([0, 0, 1, 1, 2, 2])
362    ])
363    self.assertBroadcasterEq(actual, expected)
364
365  def testGetBroadcaster2(self):
366    origin_shape = DynamicRaggedShape([], inner_shape=[])
367    dest_shape = DynamicRaggedShape([RowPartition.from_row_splits([0, 2, 3])],
368                                    inner_shape=[3])
369    actual = dynamic_ragged_shape._get_broadcaster(origin_shape, dest_shape)
370    expected = dynamic_ragged_shape._Broadcaster(origin_shape, dest_shape, [])
371    self.assertBroadcasterEq(actual, expected)
372
373  @parameterized.parameters([
374      dict(lengths=[2, 3], axis=0, expected=2),
375      dict(lengths=[2, 3], axis=1, expected=6),
376      dict(lengths=[2, 3], axis=-1, expected=6),
377      dict(lengths=[2, 3], axis=-2, expected=2),
378      dict(lengths=[2, 3, 4], axis=0, expected=2),
379      dict(lengths=[2, 3, 4], axis=1, expected=6),
380      dict(lengths=[2, 3, 4], axis=2, expected=24),
381      dict(lengths=[2, 3, 4], axis=-1, expected=24),
382      dict(lengths=[2, 3, 4], axis=-2, expected=6),
383      dict(lengths=[2, 3, 4], axis=-3, expected=2),
384      dict(lengths=[2, (2, 3), 7], axis=0, expected=2),
385      dict(lengths=[2, (2, 3), 7], axis=1, expected=5),
386      dict(lengths=[2, (2, 3), 7], axis=2, expected=35),
387      dict(lengths=[2, (2, 3), 7], axis=-1, expected=35),
388      dict(lengths=[2, (2, 3), 7], axis=-2, expected=5),
389      dict(lengths=[2, (2, 3), 7], axis=-3, expected=2),
390  ])
391  def testNumSlicesInDimension(self, lengths, axis, expected):
392    original = DynamicRaggedShape.from_lengths(lengths)
393    actual = original._num_slices_in_dimension(axis)
394    self.assertAllEqual(expected, actual)
395
396  @parameterized.parameters([
397      dict(
398          lengths=[2, 3],
399          axis=0.5,
400          error_type=TypeError,
401          error_regex='axis must be an integer'),
402  ])
403  def testNumSlicesInDimensionRaises(self, lengths, axis, error_type,
404                                     error_regex):
405    original = DynamicRaggedShape.from_lengths(lengths)
406    with self.assertRaisesRegex(error_type, error_regex):
407      original._num_slices_in_dimension(axis)
408
409  @parameterized.parameters([
410      dict(
411          lengths=[2, (1, 2), 4],
412          new_dense_rank=3,
413          error_type=ValueError,
414          error_regex='Cannot get an inner shape'),
415      dict(
416          lengths=[],
417          new_dense_rank=3,
418          error_type=ValueError,
419          error_regex='old inner_rank cannot be zero'),
420      dict(
421          lengths=[2, 3],
422          new_dense_rank=0,
423          error_type=ValueError,
424          error_regex='new_inner_rank cannot be zero'),
425  ])
426  def testAltInnerShapeRaises(self, lengths, new_dense_rank, error_type,
427                              error_regex):
428    original = DynamicRaggedShape.from_lengths(lengths)
429    with self.assertRaisesRegex(error_type, error_regex):
430      original._alt_inner_shape(new_dense_rank)
431
432  @parameterized.parameters([
433      dict(
434          lengths=[2, (1, 2), 4], new_dense_rank=2, expected_inner_shape=[3,
435                                                                          4]),
436  ])
437  def testAltInnerShape(self, lengths, new_dense_rank, expected_inner_shape):
438    original = DynamicRaggedShape.from_lengths(lengths)
439    actual = original._alt_inner_shape(new_dense_rank)
440    self.assertAllEqual(actual, expected_inner_shape)
441
442  def testWithNumRowPartitionsDynamic(self):
443    @def_function.function(
444        input_signature=[tensor_spec.TensorSpec([3], dtypes.int64)])
445    def fun(x):
446      shape = DynamicRaggedShape([
447          RowPartition.from_row_lengths([1, 3], dtype=dtypes.int64),
448          RowPartition.from_row_lengths([2, 3, 4, 5], dtype=dtypes.int64)
449      ], x)
450      result = shape._with_num_row_partitions(3)
451      expected = DynamicRaggedShape([
452          RowPartition.from_row_lengths([1, 3], dtype=dtypes.int64),
453          RowPartition.from_row_lengths([2, 3, 4, 5], dtype=dtypes.int64),
454          RowPartition.from_uniform_row_length(
455              2, nrows=14, nvals=28, dtype=dtypes.int64)
456      ], [14 * 2, 3])
457      self.assertShapeEq(expected, result)
458    fun(constant_op.constant([14, 2, 3], dtype=dtypes.int64))
459
460  @parameterized.parameters([
461      dict(
462          lengths=[2],
463          new_dense_rank=2,
464          error_type=ValueError,
465          error_regex='Cannot change inner_rank if'),
466  ])
467  def testWithDenseRankRaises(self, lengths, new_dense_rank, error_type,
468                              error_regex):
469    original = DynamicRaggedShape.from_lengths(lengths)
470    with self.assertRaisesRegex(error_type, error_regex):
471      original._with_inner_rank(new_dense_rank)
472
473  @parameterized.parameters([
474      dict(
475          lengths=[2, (1, 2)],
476          num_row_partitions=2,
477          error_type=ValueError,
478          error_regex='num_row_partitions must be less than rank'),
479      dict(
480          lengths=[2],
481          num_row_partitions=-1,
482          error_type=ValueError,
483          error_regex='num_row_partitions must be nonnegative'),
484      dict(
485          lengths=[2],
486          num_row_partitions=0.5,
487          error_type=ValueError,
488          error_regex='num_row_partitions must be an int'),
489  ])
490  def testWithNumRowPartitionsRaises(self, lengths, num_row_partitions,
491                                     error_type, error_regex):
492    original = DynamicRaggedShape.from_lengths(lengths)
493    with self.assertRaisesRegex(error_type, error_regex):
494      original._with_num_row_partitions(num_row_partitions)
495
496  def testDimensionRaises(self):
497    original = DynamicRaggedShape.from_lengths([2, (1, 2)])
498    with self.assertRaisesRegex(TypeError, 'index should be an int'):
499      # This error is not exposed directly to the end user.
500      original._dimension(0.5)
501
502  @parameterized.parameters([
503      # The whole shape (num_row_partitions=0, start=negative, stop=really big)
504      dict(lengths=[2, 3], s=slice(-1000, 100), expected_lengths=[2, 3]),
505      # The whole shape (num_row_partitions=0, stop=really big)
506      dict(lengths=[2, 3], s=slice(0, 100), expected_lengths=[2, 3]),
507      # The whole shape (num_row_partitions=0, stop=None)
508      dict(lengths=[2, 3], s=slice(0, None), expected_lengths=[2, 3]),
509      # start = None, num_row_partitions=1, stop = 3 < rank = 4
510      dict(
511          lengths=[2, (1, 2), 3, 4],
512          s=slice(None, 3),
513          expected_lengths=[2, (1, 2), 3]),
514      # start = 1, num_row_partitions=1, stop = 4, rank = 4
515      dict(
516          lengths=[2, 3, 3, 4],
517          num_row_partitions=1,
518          s=slice(1, 4),
519          expected_lengths=[3, 3, 4]),
520      # start = 1, num_row_partitions=1, stop = 3 < rank = 4
521      dict(
522          lengths=[2, 3, 3, 4],
523          num_row_partitions=1,
524          s=slice(1, 3),
525          expected_lengths=[3, 3]),
526      # start = 1, num_row_partitions=2, stop = 3 < rank = 4
527      dict(
528          lengths=[2, 3, 4, 3, 4],
529          num_row_partitions=2,
530          s=slice(1, 3),
531          expected_lengths=[3, 4]),
532      # start = 0, num_row_partitions=1, stop = 3 < rank = 4
533      dict(
534          lengths=[2, (1, 2), 3, 4],
535          s=slice(0, 3),
536          expected_lengths=[2, (1, 2), 3]),
537      # start = 0, num_row_partitions=0, stop < rank
538      dict(lengths=[2, 3, 4], s=slice(0, 2), expected_lengths=[2, 3]),
539      # start=0 < stop=2 <= num_row_partitions
540      dict(
541          lengths=[2, (1, 2), (3, 4, 5)],
542          s=slice(0, 2),
543          expected_lengths=[2, (1, 2)]),
544      # start=0 < stop=1 <= num_row_partitions
545      dict(lengths=[2, (1, 2), (3, 4, 5)], s=slice(0, 1), expected_lengths=[2]),
546      # Reversed indices, gives scalar shape.
547      dict(lengths=[2, 3], s=slice(2, 0), expected_lengths=[]),
548      # The whole shape (num_row_partitions=0)
549      dict(lengths=[2, 3], s=slice(0, 2), expected_lengths=[2, 3]),
550  ])
551  def testGetItemSlice(self,
552                       lengths,
553                       s,
554                       expected_lengths,
555                       num_row_partitions=None):
556    original = DynamicRaggedShape.from_lengths(lengths)
557    if num_row_partitions is not None:
558      original = original._with_num_row_partitions(num_row_partitions)
559    expected = DynamicRaggedShape.from_lengths(expected_lengths)
560    actual = original[s]
561    self.assertShapeEq(expected, actual)
562
563  @parameterized.parameters([
564      dict(
565          lengths=[2, (1, 2), 3, 4],
566          index=0.5,
567          error_type=TypeError,
568          error_regex='Argument is not an int or a slice'),
569      dict(
570          lengths=[2, (1, 2), 3, 4],
571          index=slice(0, 1, 2),
572          error_type=IndexError,
573          error_regex='Cannot stride through a shape'),
574      dict(
575          lengths=[2, (1, 2), 3, 4],
576          index=1,
577          error_type=ValueError,
578          error_regex='Index 1 is not uniform'),
579      dict(
580          lengths=[2, 3, 3, 4],
581          num_row_partitions=1,
582          index=-20,
583          error_type=IndexError,
584          error_regex='Index must be non-negative'),
585      dict(
586          lengths=[2, 3, 3, 4],
587          num_row_partitions=1,
588          index=9,
589          error_type=IndexError,
590          error_regex='Index is too big'),
591  ])
592  def testGetItemRaisesStatic(self,
593                              lengths,
594                              index,
595                              error_type,
596                              error_regex,
597                              num_row_partitions=None):
598    original = DynamicRaggedShape.from_lengths(lengths)
599    if num_row_partitions is not None:
600      original = original._with_num_row_partitions(num_row_partitions)
601    with self.assertRaisesRegex(error_type, error_regex):
602      original[index]  # pylint: disable=pointless-statement
603
604  def testBroadcastToAlt(self):
605    origin = RaggedTensor.from_uniform_row_length([3, 4, 5],
606                                                  uniform_row_length=1)
607    expected = RaggedTensor.from_uniform_row_length([3, 3, 4, 4, 5, 5],
608                                                    uniform_row_length=2)
609    expected_shape = DynamicRaggedShape.from_tensor(expected)
610    actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape)
611    self.assertAllEqual(actual, expected)
612
613  @parameterized.parameters([
614      dict(
615          source_lengths=[3],
616          target_lengths=[1, 3],
617          target_num_row_partitions=1,
618          expected_gather_indices=[[0, 1, 2]]),
619      dict(  # BroadcastTensorTo4 broadcaster.
620          source_lengths=[2, 3],
621          target_lengths=[1, 2, 3],
622          target_num_row_partitions=2,
623          expected_gather_indices=[[0, 1], [0, 1, 2, 3, 4, 5]]),
624      dict(  # raggedTensor1.
625          source_lengths=[3, (1, 2, 1), 2, 2],
626          source_num_row_partitions=3,
627          target_lengths=[1, 1, 3, (1, 2, 1), 2, 2],
628          target_num_row_partitions=5,
629          expected_gather_indices=[[0, 1, 2], [0, 1, 2, 3],
630                                   [0, 1, 2, 3, 4, 5, 6, 7],
631                                   [
632                                       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
633                                       13, 14, 15
634                                   ]]),
635  ])
636  def testBroadcaster(self,
637                      source_lengths,
638                      target_lengths,
639                      expected_gather_indices,
640                      source_num_row_partitions=None,
641                      target_num_row_partitions=None):
642    source = DynamicRaggedShape.from_lengths(source_lengths)
643    if source_num_row_partitions is not None:
644      source = source._with_num_row_partitions(source_num_row_partitions)
645    target = DynamicRaggedShape.from_lengths(target_lengths)
646    if target_num_row_partitions is not None:
647      target = target._with_num_row_partitions(target_num_row_partitions)
648
649    expected_gather_indices = [
650        _LayerBroadcaster.from_gather_index(x) for x in expected_gather_indices
651    ]
652    actual = dynamic_ragged_shape._get_broadcaster(source, target)
653    expected = dynamic_ragged_shape._Broadcaster(source, target,
654                                                 expected_gather_indices)
655    self.assertBroadcasterEq(actual, expected)
656
657  def testRaggedGradientSimple1(self):
658    if context.executing_eagerly():
659      return
660    def func(x):
661      rt1 = RaggedTensor.from_row_splits(
662          values=x, row_splits=[0, 4, 7, 8], validate=False)
663      rt2 = rt1 * [[10], [100], [1000]]
664      return rt2.flat_values
665
666    x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0])
667    y = func(x)
668    g = gradients_impl.gradients(ys=y, xs=x)[0]
669
670    self.assertAllClose(ops.convert_to_tensor(g),
671                        [10., 10., 10., 10., 100., 100., 100, 1000.])
672
673  def testRaggedGradientSimple2(self):
674    if context.executing_eagerly():
675      return
676    def func(x):
677      rt1 = RaggedTensor._from_row_partition(
678          x,
679          RowPartition.from_row_splits(row_splits=[0, 4, 7, 8], validate=False))
680      rt2 = rt1 * [[10], [100], [1000]]
681      return rt2.flat_values
682
683    x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0])
684    y = func(x)
685    g = gradients_impl.gradients(ys=y, xs=x)[0]
686
687    self.assertAllClose(ops.convert_to_tensor(g),
688                        [10., 10., 10., 10., 100., 100., 100, 1000.])
689
690  def testRaggedGradientSimple3(self):
691    if context.executing_eagerly():
692      return
693    def func(x):
694      rt1 = RaggedTensor._from_row_partition(
695          x,
696          RowPartition.from_row_splits(row_splits=[0, 4, 7, 8],
697                                       dtype=dtypes.int32, validate=False))
698      rt2 = rt1 * [[10], [100], [1000]]
699      return rt2.flat_values
700
701    x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0])
702    y = func(x)
703    g = gradients_impl.gradients(ys=y, xs=x)[0]
704
705    self.assertAllClose(ops.convert_to_tensor(g),
706                        [10., 10., 10., 10., 100., 100., 100, 1000.])
707
708  def testRaggedMul(self):
709    if context.executing_eagerly():
710      return
711    x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0])
712    rt1 = RaggedTensor._from_row_partition(
713        x,
714        RowPartition.from_row_splits(row_splits=[0, 4, 7, 8],
715                                     dtype=dtypes.int64, validate=False))
716    rt2 = rt1 * [[10], [100], [1000]]
717    self.assertAllClose(rt2.flat_values,
718                        [30.0, 10.0, 40.0, 10.0, 100.0, 0.0, 200.0, 1000.0])
719
720  def testBroadcastToGradient(self):
721    if context.executing_eagerly():
722      return
723    def func(x):
724      target_shape = DynamicRaggedShape.from_row_partitions(
725          [RowPartition.from_row_splits(row_splits=[0, 4, 7, 8])])
726
727      rt = dynamic_ragged_shape.broadcast_to(x, target_shape)
728      return rt.flat_values
729
730    x = constant_op.constant([[3.0], [1.0], [4.0]])
731    y = func(x)
732    g = gradients_impl.gradients(ys=y, xs=x)[0]
733
734    self.assertAllClose(g, [[4.], [3.], [1.]])
735
736  def testBroadcastScalarToScalar(self):
737    origin = constant_op.constant(b'x')
738    expected = origin
739    expected_shape = DynamicRaggedShape.from_tensor(expected)
740    actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape)
741    self.assertAllEqual(actual, expected)
742
743  @parameterized.parameters([
744      dict(lengths=[2, 3], axis=0),
745      dict(lengths=[2, 3], axis=1),
746      dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=0),
747      dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=2),
748      dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=3),
749  ])
750  def testIsUniformTrue(self, lengths, axis, num_row_partitions=None):
751    shape = DynamicRaggedShape.from_lengths(lengths)
752    if num_row_partitions is not None:
753      shape = shape._with_num_row_partitions(num_row_partitions)
754    actual = shape.is_uniform(axis)
755    self.assertTrue(actual)
756
757  @parameterized.parameters([
758      dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=1),
759      dict(
760          lengths=[2, (2, 3), 2, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 4],
761          num_row_partitions=3,
762          axis=3),
763  ])
764  def testIsUniformFalse(self, lengths, num_row_partitions, axis):
765    shape = DynamicRaggedShape.from_lengths(lengths)._with_num_row_partitions(
766        num_row_partitions)
767    actual = shape.is_uniform(axis)
768    self.assertFalse(actual)
769
770  @parameterized.parameters([
771      dict(
772          lengths=[2, (2, 3), 7, 4],
773          num_row_partitions=2,
774          axis=10,
775          error_type=IndexError,
776          error_regex='Expected axis=10 < rank=4'),
777      dict(
778          lengths=[2, (2, 3), 7, 4],
779          num_row_partitions=2,
780          axis=-1,
781          error_type=IndexError,
782          error_regex='Negative axis values are not supported'),
783      dict(
784          lengths=[2, (2, 3), 7, 4],
785          num_row_partitions=2,
786          axis=0.5,
787          error_type=TypeError,
788          error_regex='axis must be an integer'),
789  ])
790  def testIsUniformRaises(self, lengths, num_row_partitions, axis, error_type,
791                          error_regex):
792    shape = DynamicRaggedShape.from_lengths(lengths)._with_num_row_partitions(
793        num_row_partitions)
794    with self.assertRaisesRegex(error_type, error_regex):
795      shape.is_uniform(axis)
796
797  @parameterized.parameters([
798      dict(lengths=[2, 3], num_row_partitions_a=0, num_row_partitions_b=1),
799      dict(
800          lengths=[2, (2, 3), 7, 4],
801          num_row_partitions_a=2,
802          num_row_partitions_b=1),
803      dict(
804          lengths=[3, (2, 0, 1), 5],
805          num_row_partitions_a=1,
806          num_row_partitions_b=2)
807  ])
808  def testWithNumRowPartitions(self, lengths, num_row_partitions_a,
809                               num_row_partitions_b):
810    shape = DynamicRaggedShape.from_lengths(lengths)
811    original_row_partitions = shape.num_row_partitions
812    shape_a = shape._with_num_row_partitions(num_row_partitions_a)
813    self.assertEqual(shape_a.num_row_partitions, num_row_partitions_a)
814    shape_b = shape_a._with_num_row_partitions(num_row_partitions_b)
815    self.assertEqual(shape_b.num_row_partitions, num_row_partitions_b)
816    actual = shape_b._with_num_row_partitions(original_row_partitions)
817    self.assertShapeEq(actual, shape)
818
819  @parameterized.parameters([
820      dict(
821          lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=-2, expected=7),
822      dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=0, expected=2),
823      dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=2, expected=7),
824      dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=3, expected=4),
825      dict(
826          lengths=[2, (2, 3), 7, 4, 3],
827          num_row_partitions=2,
828          axis=4,
829          expected=3),
830      dict(lengths=[3], axis=0, expected=3),
831      dict(lengths=[3, 4, 5], axis=0, expected=3),
832      dict(lengths=[3, 4, 5], axis=1, expected=4),
833      dict(lengths=[3, 4, 5], axis=2, expected=5),
834  ])
835  def testGetItem(self, lengths, axis, expected, num_row_partitions=None):
836    shape = DynamicRaggedShape.from_lengths(lengths)
837    if num_row_partitions is not None:
838      shape = shape._with_num_row_partitions(num_row_partitions)
839    actual = shape[axis]
840    self.assertAllEqual(actual, expected)
841
842  def testNumElements(self):
843    shape = DynamicRaggedShape.from_lengths([2, 3, 4,
844                                             5])._with_num_row_partitions(2)
845    self.assertAllEqual(shape._num_elements(), 120)
846
847  def test_to_row_partitions_from_lengths(self):
848    # Testing the test.
849    actual = _to_row_partitions_from_lengths([1, 2, 3])
850    expected = [
851        RowPartition.from_row_splits([0, 2]),
852        RowPartition.from_row_splits([0, 3, 6])
853    ]
854    self.assertRowPartitionEq(actual[0], expected[0])
855    self.assertRowPartitionEq(actual[1], expected[1])
856
857  @parameterized.parameters([
858      dict(
859          origin=b'x',
860          expected_lengths=[2, (1, 2)],
861          expected=[[b'x'], [b'x', b'x']]),
862      dict(
863          origin=b'x',
864          expected_lengths=[1, 1, 1],
865          expected_num_row_partitions=2,
866          expected=[[[b'x']]]),
867      dict(
868          origin=[b'a', b'b', b'c'],
869          expected_lengths=[3],
870          expected=[b'a', b'b', b'c']),
871      dict(
872          origin=[b'a', b'b', b'c'],
873          expected_lengths=[1, 1, 3],
874          expected_num_row_partitions=2,
875          expected=[[[b'a', b'b', b'c']]]),
876      dict(
877          origin=[[b'a', b'b', b'c'], [b'd', b'e', b'f']],
878          expected_lengths=[1, 2, 3],
879          expected_num_row_partitions=2,
880          expected=[[[b'a', b'b', b'c'], [b'd', b'e', b'f']]]),
881  ])
882  def testBroadcastTensorTo(self,
883                            origin,
884                            expected_lengths,
885                            expected,
886                            expected_num_row_partitions=None):
887    origin = constant_op.constant(origin)
888    expected_shape = DynamicRaggedShape.from_lengths(expected_lengths)
889    if expected_num_row_partitions is not None:
890      expected_shape = expected_shape._with_num_row_partitions(
891          expected_num_row_partitions)
892    expected = ragged_factory_ops.constant_value(expected)
893    actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape)
894    self.assertAllEqual(actual, expected)
895
896  def testBroadcastFlatValues(self):
897    origin_lengths = [3, (1, 2, 1), 2, 2]
898    dest_lengths = [1, 1, 3, (1, 2, 1), 2, 2]
899    origin_values = constant_op.constant([
900        b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l',
901        b'm', b'n', b'o', b'p'
902    ])
903    origin_shape = DynamicRaggedShape.from_lengths(
904        origin_lengths)._with_num_row_partitions(3)
905    dest_shape = DynamicRaggedShape.from_lengths(
906        dest_lengths)._with_num_row_partitions(5)
907
908    broadcaster = dynamic_ragged_shape._get_broadcaster(origin_shape,
909                                                        dest_shape)
910
911    actual = broadcaster.broadcast_flat_values(origin_values)
912
913    self.assertAllEqual(origin_values, actual)
914
915  @parameterized.parameters([
916      dict(
917          origin_lengths=[3],
918          origin_values=[b'a', b'b', b'c'],
919          expected_lengths=[2],
920          expected_values=[[b'a', b'b', b'c'], [b'a', b'b', b'c']]),
921      dict(
922          origin_lengths=[3, (3, 2, 4)],
923          origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 89],
924          expected_lengths=[3, (3, 2, 4)],
925          expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 89]),
926      dict(
927          origin_lengths=[3, (3, 2, 4)],
928          origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 89],
929          expected_lengths=[1, 3, (3, 2, 4)],
930          expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 89]),
931      dict(
932          origin_lengths=[3, (3, 2, 4)],
933          origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 89],
934          expected_lengths=[1, 1, 3, (3, 2, 4)],
935          expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 89]),
936      # Broadcast [1, 2, (1, 2)] to [2, 2, (1, 2, 1, 2)]
937      dict(
938          origin_lengths=[1, 2, (1, 2)],
939          origin_values=[2, 3, 5],
940          expected_lengths=[2, 2, (1, 2, 1, 2)],
941          expected_values=[2, 3, 5, 2, 3, 5]),
942      # Broadcast [2, 1, (1, 2)] to [2, 2, (1, 1, 2, 2)] (NEW)
943      dict(
944          origin_lengths=[2, 1, (1, 2)],
945          origin_values=[2, 3, 5],
946          expected_lengths=[2, 2, (1, 1, 2, 2)],
947          expected_values=[2, 2, 3, 5, 3, 5]),
948      dict(
949          origin_lengths=[2, 1, 1],
950          origin_values=[2, 3],  # [[[2]], [[3]]]
951          expected_lengths=[2, 1, (3, 3)],
952          expected_values=[2, 2, 2, 3, 3, 3]),
953      dict(
954          origin_lengths=[3],
955          origin_values=[b'a', b'b', b'c'],
956          expected_lengths=[4, 2, 3],
957          expected_values=[
958              b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b',
959              b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a',
960              b'b', b'c'
961          ]),
962      dict(
963          origin_lengths=[2, 3],
964          origin_values=[b'a', b'b', b'c', b'a', b'b', b'c'],
965          expected_lengths=[4, 2, 3],
966          expected_values=[
967              b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b',
968              b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a',
969              b'b', b'c'
970          ]),
971      dict(
972          origin_lengths=[3, (1, 2, 1), 2, 2],
973          origin_values=[
974              b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k',
975              b'l', b'm', b'n', b'o', b'p'
976          ],
977          expected_lengths=[1, 1, 3, (1, 2, 1), 2, 2],
978          expected_values=[
979              b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k',
980              b'l', b'm', b'n', b'o', b'p'
981          ]),
982      dict(
983          origin_lengths=[3, (1, 2, 1), 2, 2],
984          origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 7, 4, 5, 6, 1, 2, 3, 7],
985          expected_lengths=[1, 1, 3, (1, 2, 1), 2, 2],
986          expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 7, 4, 5, 6, 1, 2, 3, 7],
987      ),
988  ])
989  def testBroadcastRaggedTo(self, origin_lengths, origin_values,
990                            expected_lengths, expected_values):
991    origin = _to_ragged_tensor_from_lengths(origin_values, origin_lengths)
992    expected = _to_ragged_tensor_from_lengths(expected_values, expected_lengths)
993    expected_shape = DynamicRaggedShape.from_tensor(expected)
994    actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape)
995    self.assertAllEqual(actual, expected)
996
997  def testDynamicRaggedShapeFromTensor2(self):
998    raw_rt = [[[[7, 4], [5, 6]], [[1, 2], [3, 7]]], [[[7, 4], [5, 6]]],
999              [[[1, 2], [3, 7]]]]
1000    raw_rt = ragged_factory_ops.constant_value(raw_rt)
1001    actual_shape = DynamicRaggedShape.from_tensor(raw_rt)
1002    expected_shape = DynamicRaggedShape.from_lengths(
1003        [3, (2, 1, 1), 2, 2])._with_num_row_partitions(3)
1004    self.assertShapeEq(actual_shape, expected_shape)
1005
1006  # pylint: disable=g-long-lambda
1007  @parameterized.parameters([
1008      # A row partition as opposed to a list of row partitions.
1009      dict(
1010          row_partitions=lambda: RowPartition.from_row_splits([0, 2, 3]),
1011          inner_shape=lambda: [4],
1012          error_type=TypeError,
1013          error_regex='row_partitions should be'),
1014      # A list of lists of integers for row_partitions.
1015      dict(
1016          row_partitions=lambda: [[0, 2, 3]],
1017          inner_shape=lambda: [4],
1018          error_type=TypeError,
1019          error_regex='row_partitions contains'),
1020      # nvals and nrows don't match (3 != 6) statically
1021      dict(
1022          row_partitions=lambda: [  # pylint: disable=g-long-lambda
1023              RowPartition.from_value_rowids([0, 2, 4], nrows=5),
1024              RowPartition.from_value_rowids([0, 2, 5], nrows=6)
1025          ],
1026          inner_shape=lambda: [3],
1027          validate=False,
1028          error_type=ValueError,
1029          error_regex='RowPartitions in DynamicRaggedShape do not'),
1030      # nvals and inner_shape[0] don't match (3 != 6) statically
1031      dict(
1032          row_partitions=lambda: [
1033              RowPartition.from_value_rowids([0, 2, 4], nrows=5),
1034          ],
1035          inner_shape=lambda: [6],
1036          validate=False,
1037          error_type=ValueError,
1038          error_regex='Last row partition does not match inner_shape.'),
1039  ])
1040  def testConstructorRaisesStatic(self,
1041                                  row_partitions,
1042                                  inner_shape,
1043                                  error_type,
1044                                  error_regex,
1045                                  validate=False,
1046                                  dtype=None):
1047    row_partitions = row_partitions()
1048    inner_shape = inner_shape()
1049    with self.assertRaisesRegex(error_type, error_regex):
1050      DynamicRaggedShape(
1051          row_partitions, inner_shape, dtype=dtype, validate=validate)
1052
1053  def testConstructorStaticOK(self):
1054    row_partitions = [
1055        RowPartition.from_value_rowids([0, 2, 4], nrows=5),
1056        RowPartition.from_value_rowids([0, 1, 2], nrows=3)
1057    ]
1058    inner_shape = [3]
1059    rts = DynamicRaggedShape(row_partitions, inner_shape, validate=True)
1060    static_inner_shape = tensor_util.constant_value(rts.inner_shape)
1061    static_valid_rowids0 = tensor_util.constant_value(
1062        rts.row_partitions[0].value_rowids())
1063    static_valid_rowids1 = tensor_util.constant_value(
1064        rts.row_partitions[1].value_rowids())
1065    self.assertAllEqual(static_inner_shape, [3])
1066    self.assertAllEqual(static_valid_rowids0, [0, 2, 4])
1067    self.assertAllEqual(static_valid_rowids1, [0, 1, 2])
1068
1069  def testConstructorWithStaticInnerShape(self):
1070    row_partitions = [
1071        RowPartition.from_value_rowids([0, 2, 4], nrows=5),
1072        RowPartition.from_value_rowids([0, 1, 2], nrows=3)
1073    ]
1074    inner_shape = [3]
1075    rts = DynamicRaggedShape(row_partitions, inner_shape, validate=True,
1076                             static_inner_shape=[3])
1077    static_inner_shape = tensor_util.constant_value(rts.inner_shape)
1078    static_valid_rowids0 = tensor_util.constant_value(
1079        rts.row_partitions[0].value_rowids())
1080    static_valid_rowids1 = tensor_util.constant_value(
1081        rts.row_partitions[1].value_rowids())
1082    self.assertAllEqual(static_inner_shape, [3])
1083    self.assertAllEqual(static_valid_rowids0, [0, 2, 4])
1084    self.assertAllEqual(static_valid_rowids1, [0, 1, 2])
1085
1086  def testZeros(self):
1087    shape_x = DynamicRaggedShape.from_lengths([3, (1, 3, 2), 4])
1088    foo = ragged_array_ops.zeros(shape_x)
1089    self.assertShapeEq(shape_x, DynamicRaggedShape.from_tensor(foo))
1090    self.assertAllEqual(array_ops.zeros([6, 4]), foo.flat_values)
1091
1092  def testOnes(self):
1093    shape_x = DynamicRaggedShape.from_lengths([3, (1, 3, 2), 4])
1094    foo = ragged_array_ops.ones(shape_x)
1095    self.assertShapeEq(shape_x, DynamicRaggedShape.from_tensor(foo))
1096    self.assertAllEqual(array_ops.ones([6, 4]), foo.flat_values)
1097
1098  def testReshapeTensor(self):
1099    foo = array_ops.zeros([3, 2, 4])
1100    shape_b = DynamicRaggedShape.from_lengths([3, (3, 2, 1), 4])
1101    result = ragged_array_ops.ragged_reshape(foo, shape_b)
1102    self.assertShapeEq(shape_b, DynamicRaggedShape.from_tensor(result))
1103    self.assertAllEqual(array_ops.zeros([6, 4]), result.flat_values)
1104
1105  def test_reshape_ragged_tensor(self):
1106    shape_x = DynamicRaggedShape.from_lengths([3, (1, 3, 2), 4])
1107    foo = ragged_array_ops.zeros(shape_x)
1108    shape_b = DynamicRaggedShape.from_lengths([3, (3, 2, 1), 4])
1109    result = ragged_array_ops.ragged_reshape(foo, shape_b)
1110    self.assertShapeEq(shape_b, DynamicRaggedShape.from_tensor(result))
1111    self.assertAllEqual(array_ops.zeros([6, 4]), result.flat_values)
1112
1113  @parameterized.parameters([
1114      dict(
1115          lengths_a=[3, (1, 4, 2)],
1116          lengths_b=[3, (1, 4, 2)],
1117          lengths_e=[3, (1, 4, 2)]),
1118      dict(
1119          lengths_a=[1, 2, (1, 4)],
1120          lengths_b=[3, 2, (1, 4, 1, 4, 1, 4)],
1121          lengths_e=[3, 2, (1, 4, 1, 4, 1, 4)]),
1122      dict(
1123          lengths_a=[1, 1],
1124          num_row_partitions_a=1,
1125          lengths_b=[3, 5],
1126          num_row_partitions_b=1,
1127          lengths_e=[3, 5],
1128          num_row_partitions_e=1),
1129      dict(lengths_a=[1, 4, 5], lengths_b=[3, 1, 1], lengths_e=[3, 4, 5]),
1130      dict(lengths_a=[3], lengths_b=[4, 2, 1], lengths_e=[4, 2, 3]),
1131      dict(lengths_a=[2, 3], lengths_b=[4, 2, 1], lengths_e=[4, 2, 3]),
1132      # Outermost dimension-both partitioned
1133      # Also, neither has uniform_row_length
1134      dict(
1135          lengths_a=[2, (1, 3), 1],
1136          lengths_b=[2, (1, 3), (3, 4, 5, 6)],
1137          lengths_e=[2, (1, 3), (3, 4, 5, 6)]),
1138      # Outermost dimension-Only one is partitioned
1139      # Also, partitioned dimension doesn't have uniform_row_length
1140      dict(
1141          lengths_a=[2, 1, 5],
1142          lengths_b=[2, (1, 3), 5],
1143          num_row_partitions_b=2,
1144          lengths_e=[2, (1, 3), 5],
1145          num_row_partitions_e=2),
1146
1147      # Cover [5, R], [1, 5, R]
1148      dict(
1149          lengths_a=[5, (1, 2, 0, 3, 1)],
1150          lengths_b=[1, 5, (1, 2, 0, 3, 1)],
1151          lengths_e=[1, 5, (1, 2, 0, 3, 1)]),
1152      # When two uniform row lengths are equal
1153      dict(
1154          lengths_a=[1, 5],
1155          num_row_partitions_a=1,
1156          lengths_b=[3, 5],
1157          num_row_partitions_b=1,
1158          lengths_e=[3, 5],
1159          num_row_partitions_e=1),
1160      # Dense + Partitioned dimension has uniform_row_length
1161      # [1, 3, [5, 1, 6]] and DENSE [2, 1, 1] -> [2, 3, [5, 1, 6, 5, 1, 6]]
1162      dict(
1163          lengths_a=[1, 3, (5, 1, 6)],
1164          lengths_b=[2, 1, 1],
1165          lengths_e=[2, 3, (5, 1, 6, 5, 1, 6)]),
1166      # Both partitioned; one has uniform_row_length
1167      # (uniform_row_length [2,1,1]) and [2,[1,3],[3,4,5,6]]
1168      dict(
1169          lengths_a=[2, 1, 1],
1170          num_row_partitions_a=2,
1171          lengths_b=[2, (1, 3), (3, 4, 5, 6)],
1172          lengths_e=[2, (1, 3), (3, 4, 5, 6)]),
1173      # When broadcasting uniform_row_length to uniform_row_length.
1174      # Also, both have uniform_row_length
1175      dict(
1176          lengths_a=[3, 1, 5],
1177          num_row_partitions_a=2,
1178          lengths_b=[3, 4, 5],
1179          num_row_partitions_b=2,
1180          lengths_e=[3, 4, 5],
1181          num_row_partitions_e=2),
1182      # When broadcasting above a U_R_L
1183      # [2,1, 5] and [2, [1,3], 5] -> [2, [1,3], 5]
1184      dict(
1185          lengths_a=[2, 1, 5],
1186          num_row_partitions_a=2,
1187          lengths_b=[2, (1, 3), 5],
1188          num_row_partitions_b=2,
1189          lengths_e=[2, (1, 3), 5],
1190          num_row_partitions_e=2),
1191      # What if the larger-dimensional shape has uniform_row_length on the
1192      # matching dim, but has larger dimensions above
1193      # ([3,1,5],[15]) vs ([2,1],[2]))
1194      dict(
1195          lengths_a=[3, 1, 5],
1196          num_row_partitions_a=2,
1197          lengths_b=[2, 1],
1198          num_row_partitions_b=1,
1199          lengths_e=[3, 2, 5],
1200          num_row_partitions_e=2),
1201      # Inner non-ragged dimensions
1202      # Can delegate to dense broadcast operations.
1203      # Implementation detail: not testable.
1204      # ([2, [1,2]],[3,2,1]) and ([2,1],[2,1,3])
1205      dict(
1206          lengths_a=[2, (1, 2), 2, 1],
1207          lengths_b=[2, 1, 1, 3],
1208          num_row_partitions_b=1,
1209          lengths_e=[2, (1, 2), 2, 3],
1210      ),
1211  ])
1212  def testBroadcastDynamicShapeExtended(self,
1213                                        lengths_a,
1214                                        lengths_b,
1215                                        lengths_e,
1216                                        num_row_partitions_a=None,
1217                                        num_row_partitions_b=None,
1218                                        num_row_partitions_e=None):
1219    # This test is predicated on the fact that broadcast_to is correct.
1220    # Thus, it tests:
1221    # Whether the shape generated is correct.
1222    # Whether broadcasting is the same as broadcast_to.
1223    # Instead of specifying values, it just uses primes.
1224    shape_a = DynamicRaggedShape.from_lengths(lengths_a)
1225    if num_row_partitions_a is not None:
1226      shape_a = shape_a._with_num_row_partitions(num_row_partitions_a)
1227    shape_b = DynamicRaggedShape.from_lengths(lengths_b)
1228    if num_row_partitions_b is not None:
1229      shape_b = shape_b._with_num_row_partitions(num_row_partitions_b)
1230    shape_e = DynamicRaggedShape.from_lengths(lengths_e)
1231    if num_row_partitions_e is not None:
1232      shape_e = shape_e._with_num_row_partitions(num_row_partitions_e)
1233
1234    [actual, bc_a, bc_b
1235    ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_a, shape_b)
1236    [actual_rev, bc_b_rev, bc_a_rev
1237    ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_b, shape_a)
1238    self.assertShapeEq(actual, shape_e)
1239    self.assertShapeEq(actual_rev, shape_e)
1240
1241    rt_a = ragged_array_ops.ragged_reshape(
1242        _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a)
1243    bc_a_actual = bc_a.broadcast(rt_a)
1244    bc_a_actual_rev = bc_a_rev.broadcast(rt_a)
1245    bc_a_expected = dynamic_ragged_shape.broadcast_to(rt_a, shape_e)
1246    self.assertAllEqual(bc_a_expected, bc_a_actual)
1247    self.assertAllEqual(bc_a_expected, bc_a_actual_rev)
1248
1249    rt_b = ragged_array_ops.ragged_reshape(
1250        _lowest_primes(_num_elements_of_lengths(lengths_b)), shape_b)
1251    bc_b_expected = dynamic_ragged_shape.broadcast_to(rt_b, shape_e)
1252    bc_b_actual = bc_b.broadcast(rt_b)
1253    bc_b_actual_rev = bc_b_rev.broadcast(rt_b)
1254    self.assertAllEqual(bc_b_expected, bc_b_actual)
1255    self.assertAllEqual(bc_b_expected, bc_b_actual_rev)
1256
1257  @parameterized.parameters([
1258      dict(
1259          lengths=[3, (1, 4, 2)],
1260          dense_rank=1,
1261          lengths_e=[3, (1, 4, 2)],
1262      ),
1263      dict(
1264          lengths=[3, (1, 4, 2), 5],
1265          dense_rank=2,
1266          lengths_e=[3, (1, 4, 2), 5],
1267      ),
1268      dict(
1269          lengths=[3],
1270          dense_rank=1,
1271          lengths_e=[3],
1272      ),
1273  ])
1274  def testWithDenseRank(self, lengths, dense_rank, lengths_e):
1275    # Makes little sense with from_lengths/_with_num_row_partitions.
1276    original = DynamicRaggedShape.from_lengths(lengths)
1277    actual = original._with_inner_rank(dense_rank)
1278    self.assertAllEqual(actual.inner_rank, dense_rank)
1279    self.assertAllEqual(actual.static_lengths(), lengths_e)
1280
1281  @parameterized.parameters([
1282      dict(
1283          rps=[3, [1, 4, 2]],
1284          lengths_e=[3, (1, 4, 2)],
1285          num_row_partitions_e=1,
1286      ),
1287      dict(
1288          rps=[3, [1, 4, 2], 2],
1289          lengths_e=[3, (1, 4, 2), 2],
1290          num_row_partitions_e=2,
1291      ),
1292  ])
1293  def testFromRowPartitions(self, rps, lengths_e, num_row_partitions_e):
1294    rps = _to_row_partitions_from_lengths(rps)
1295    actual = DynamicRaggedShape.from_row_partitions(rps)
1296    expected = DynamicRaggedShape.from_lengths(
1297        lengths_e)._with_num_row_partitions(num_row_partitions_e)
1298    self.assertShapeEq(expected, actual)
1299
1300  def testFromRowPartitionsError(self):
1301    with self.assertRaisesRegex(ValueError, 'row_partitions cannot be empty'):
1302      DynamicRaggedShape.from_row_partitions([])
1303
1304  @parameterized.parameters([
1305      #=========================================================================
1306      # dimension[axis] is uniform inner; and row_lengths is a scalar
1307      #=========================================================================
1308      # shape: [BROADCAST(UNIFORM), UNIFORM, UNIFORM]
1309      dict(original_lengths=[1, 4, 5],
1310           broadcast_lengths=[3, 4, 5]),
1311      # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
1312      dict(original_lengths=[3, 4, 1],
1313           broadcast_lengths=[3, 4, 5]),
1314      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
1315      dict(original_lengths=[3, (3, 2, 8), 1],
1316           broadcast_lengths=[3, (3, 2, 8), 5]),
1317      # shape: [UNIFORM, RAGGED, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
1318      dict(original_lengths=[2, (2, 1), (3, 2, 8), 3, 4, 1],
1319           broadcast_lengths=[2, (2, 1), (3, 2, 8), 3, 4, 5]),
1320
1321      #=========================================================================
1322      # dimension[axis] is uniform inner; and row_lengths is a vector
1323      #=========================================================================
1324      # shape: [UNIFORM, BROADCAST(UNIFORM)]
1325      dict(original_lengths=[3, 1],
1326           broadcast_lengths=[3, (2, 0, 1)]),
1327      # shape: [UNIFORM, BROADCAST(UNIFORM), UNIFORM]
1328      dict(original_lengths=[3, 1, 5],
1329           broadcast_lengths=[3, (2, 0, 1), 5]),
1330
1331      # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)]
1332      dict(original_lengths=[4, 3, 1],
1333           broadcast_lengths=[4, 3, (2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0)]),
1334
1335      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)]
1336      dict(original_lengths=[2, (2, 1), 1],
1337           broadcast_lengths=[2, (2, 1), (2, 5, 3)]),
1338
1339      # shape: [UNIFORM, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM), UNIFORM]
1340      dict(original_lengths=[2, (2, 1), 3, 2, 1, 8],
1341           broadcast_lengths=[2, (2, 1), 3, 2, tuple(range(18)), 8]),
1342
1343      #=========================================================================
1344      # dimension[axis] is uniform partitioned; and row_lengths is a scalar
1345      #=========================================================================
1346      # shape: [BROADCAST(UNIFORM), RAGGED]
1347      dict(original_lengths=[1, (5,)],
1348           broadcast_lengths=[3, (5, 5, 5)]),
1349
1350      # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED]
1351      dict(original_lengths=[1, 3, (3, 0, 2)],
1352           broadcast_lengths=[2, 3, (3, 0, 2, 3, 0, 2)]),
1353
1354      # shape: [BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM, UNIFORM]
1355      dict(original_lengths=[1, (3,), (3, 5, 2), 9, 4, 5],
1356           broadcast_lengths=[3, (3, 3, 3), (3, 5, 2, 3, 5, 2, 3, 5, 2),
1357                              9, 4, 5]),
1358
1359      # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED, UNIFORM]
1360      dict(original_lengths=[1, 2, (2, 1), (3, 5, 2), 2],
1361           broadcast_lengths=[2, 2, (2, 1, 2, 1), (3, 5, 2, 3, 5, 2), 2]),
1362
1363      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
1364      # This is wrong. should broadcast to [3, 2, (4, 4, 0, 0, 2, 2), 5]
1365      # dict(original_lengths=[3, 1, [4, 0, 2], 5],
1366      #      broadcast_lengths=[3, 2, [4, 0, 2, 4, 0, 2], 5]),
1367      dict(original_lengths=[3, 1, (4, 0, 2), 5],
1368           broadcast_lengths=[3, 2, (4, 4, 0, 0, 2, 2), 5]),
1369
1370      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED]
1371      dict(original_lengths=[2, 3, (1, 2, 3, 4, 5, 6)],
1372           broadcast_lengths=[2, 3, (1, 2, 3, 4, 5, 6)]),
1373
1374      #=========================================================================
1375      # dimension[axis] is uniform partitioned; and row_lengths is a vector
1376      #=========================================================================
1377      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM]
1378      dict(original_lengths=[
1379          3,                          # axis=0
1380          1,                          # axis=1 (broadcast)
1381          (3, 1, 2),                  # axis=2
1382          5],                         # axis=3
1383           broadcast_lengths=[
1384               3,                          # axis=0
1385               (4, 1, 2),                  # axis=1 (broadcast)
1386               (3, 3, 3, 3, 1, 2, 2),      # axis=2
1387               5]),                        # axis=3
1388
1389      # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, RAGGED]
1390      dict(original_lengths=[
1391          3,                                         # axis=0
1392          1,                                         # axis=1 (broadcast)
1393          (3, 1, 2),                                 # axis=2
1394          (3, 1, 4, 1, 5, 9)],                       # axis=3
1395           broadcast_lengths=[
1396               3,                                         # axis=0
1397               (2, 0, 3),                                 # axis=1 (broadcast)
1398               (3, 3, 2, 2, 2),                           # axis=2
1399               (3, 1, 4, 3, 1, 4, 5, 9, 5, 9, 5, 9)]),    # axis=3
1400
1401      # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM]
1402      dict(original_lengths=[
1403          3,                                         # axis=0
1404          (2, 0, 1),                                 # axis=1
1405          1,                                         # axis=2 (broadcast)
1406          (3, 2, 1),                                 # axis=3
1407          (1, 0, 1, 0, 2, 3),                        # axis=4
1408          5],                                        # axis=5
1409           broadcast_lengths=[
1410               3,                                         # axis=0
               (2, 0, 1),                                 # axis=1
1412               (4, 1, 2),                                 # axis=2 (broadcast)
1413               (3, 3, 3, 3, 2, 1, 1),                     # axis=3
1414               (1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0,    # axis=4
1415                2, 3, 3),
1416               5]),                                       # axis=5
1417      dict(original_lengths=[1, 1, 2, (2, 1)],
1418           broadcast_lengths=[2, 1, 2, (2, 1, 2, 1)]),
1419      dict(original_lengths=[2, 1, 2, (2, 1, 2, 1)],
1420           broadcast_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
1421      dict(original_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)],
1422           broadcast_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
1423      dict(original_lengths=[2, (2, 1), 2, 1],
1424           broadcast_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
1425  ])  # pyformat: disable
1426  def testBroadcastDimension(self, original_lengths, broadcast_lengths):
1427    """Tests broadcast_to on a single dimension."""
1428    original_rt = _to_prime_tensor_from_lengths(original_lengths)
1429    bcast_shape = DynamicRaggedShape.from_lengths(broadcast_lengths)
1430    result_rt = dynamic_ragged_shape.broadcast_to(original_rt, bcast_shape)
1431    result_shape = DynamicRaggedShape.from_tensor(result_rt)
1432
1433    self.assertShapeEq(bcast_shape, result_shape)
1434
1435  def testAsRowPartitions(self):
1436    my_shape = DynamicRaggedShape.from_lengths([3, (2, 0, 1), 5])
1437    rps = my_shape._as_row_partitions()
1438    self.assertLen(rps, 2)
1439
1440  def testAsRowPartitionsRaises(self):
1441    my_shape = DynamicRaggedShape.from_lengths([])
1442    with self.assertRaisesRegex(ValueError,
1443                                'rank must be >= 1 for _as_row_partitions'):
1444      my_shape._as_row_partitions()
1445
1446  def testToPrimeTensorFromDimSizes(self):
1447    """Tests the test utility."""
1448    original_lengths = [3, (3, 2, 8), 1]
1449    original_rt = _to_prime_tensor_from_lengths(original_lengths)
1450    expected_rt = _to_ragged_tensor_from_lengths(
1451        [[2], [3], [5], [7], [11], [13], [17], [19], [23], [29], [31], [37],
1452         [41]], [3, (3, 2, 8)])
1453
1454    self.assertAllEqual(expected_rt, original_rt)
1455
1456  @parameterized.parameters([
1457      # Broadcast scalar
1458      dict(x_dims=[], y_dims=[], expected_dims=[]),
1459      dict(x_dims=[], y_dims=[2], expected_dims=[2]),
1460      dict(x_dims=[], y_dims=[2, 3], expected_dims=[2, 3]),
1461      dict(
1462          x_dims=[],
1463          y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
1464          expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
1465      # Broadcast vector
1466      dict(x_dims=[3], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
1467      dict(x_dims=[1], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]),
1468      dict(x_dims=[3], y_dims=[4, 2, 1], expected_dims=[4, 2, 3]),
1469      dict(
1470          x_dims=[3], y_dims=[3, (2, 3, 1), 1], expected_dims=[3, (2, 3, 1),
1471                                                               3]),
1472      dict(x_dims=[1], y_dims=[3, (2, 1, 3)], expected_dims=[3, (2, 1, 3)]),
1473      dict(
1474          x_dims=[1], y_dims=[3, (2, 1, 3), 8], expected_dims=[3, (2, 1, 3),
1475                                                               8]),
1476      dict(
1477          x_dims=[1],
1478          y_dims=[2, (2, 3), (5, 7, 2, 0, 9)],
1479          expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]),
1480      # Mixed broadcasting
1481      dict(
1482          x_dims=[
1483              1,  # axis=0
1484              3,  # axis=1
1485              (3, 0, 2),  # axis=2
1486              1,  # axis=3
1487              2,  # axis=4
1488          ],
1489          y_dims=[
1490              2,  # axis=0
1491              1,  # axis=1
1492              1,  # axis=2
1493              (7, 2),  # axis=3
1494              1,  # axis=4
1495          ],
1496          expected_dims=[
1497              2,  # axis=0
1498              3,  # axis=1
1499              (3, 0, 2, 3, 0, 2),  # axis=2
1500              (7, 7, 7, 7, 7, 2, 2, 2, 2, 2),  # axis=3
1501              2,  # axis=4
1502          ]),
1503      dict(
1504          x_dims=[2, (2, 1), 2, 1],
1505          y_dims=[1, 1, 2, (2, 1)],
1506          expected_dims=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]),
1507  ])
1508  def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims):
1509    shape_a = DynamicRaggedShape.from_lengths(x_dims)
1510    shape_b = DynamicRaggedShape.from_lengths(y_dims)
1511    shape_e = DynamicRaggedShape.from_lengths(expected_dims)
1512    [actual, bc_a, bc_b
1513    ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_a, shape_b)
1514    [actual_rev, bc_b_rev, bc_a_rev
1515    ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_b, shape_a)
1516    self.assertShapeEq(actual, shape_e)
1517    self.assertShapeEq(actual_rev, shape_e)
1518
1519    rt_a = _to_prime_tensor_from_lengths(x_dims)
1520    bc_a_actual = bc_a.broadcast(rt_a)
1521    bc_a_actual_rev = bc_a_rev.broadcast(rt_a)
1522    bc_a_expected = dynamic_ragged_shape.broadcast_to(rt_a, shape_e)
1523    self.assertAllEqual(bc_a_expected, bc_a_actual)
1524    self.assertAllEqual(bc_a_expected, bc_a_actual_rev)
1525
1526    rt_b = _to_prime_tensor_from_lengths(y_dims)
1527    bc_b_expected = dynamic_ragged_shape.broadcast_to(rt_b, shape_e)
1528    bc_b_actual = bc_b.broadcast(rt_b)
1529    bc_b_actual_rev = bc_b_rev.broadcast(rt_b)
1530    self.assertAllEqual(bc_b_expected, bc_b_actual)
1531    self.assertAllEqual(bc_b_expected, bc_b_actual_rev)
1532
1533    # This just wraps broadcast_dynamic_shape_extended, so nothing
1534    # deeper is required.
1535    result1 = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b)
1536    self.assertShapeEq(shape_e, result1)
1537
1538    # Again, just a wrapper.
1539    result2 = ragged_array_ops.broadcast_dynamic_shape(shape_a, shape_b)
1540    self.assertShapeEq(shape_e, result2)
1541
1542  def testBroadcastDynamicShapeFirstLayer(self):
1543    a_0 = constant_op.constant(1, dtypes.int64)
1544    b_0 = constant_op.constant(3, dtypes.int64)
1545    [a_layer, b_layer
1546    ] = dynamic_ragged_shape._broadcast_dynamic_shape_first_layer(a_0, b_0)
1547    expected_a_layer = _LayerBroadcaster.from_gather_index([0, 0, 0])
1548    expected_b_layer = _LayerBroadcaster.from_gather_index([0, 1, 2])
1549    self.assertLayerBroadcasterEq(expected_a_layer, a_layer)
1550    self.assertLayerBroadcasterEq(expected_b_layer, b_layer)
1551
1552  def testBroadcastDynamicShapeNextLayer(self):
1553    a_1 = RowPartition.from_uniform_row_length(
1554        1, nvals=1, nrows=1, dtype_hint=dtypes.int64)
1555    b_1 = RowPartition.from_row_lengths([2, 1, 3], dtype_hint=dtypes.int64)
1556    ac_0 = _LayerBroadcaster.from_gather_index(
1557        constant_op.constant([0, 0, 0], dtype=dtypes.int64))
1558    bc_0 = _LayerBroadcaster.from_gather_index(
1559        constant_op.constant([0, 1, 2], dtype=dtypes.int64))
1560    dynamic_ragged_shape._broadcast_dynamic_shape_next_layer_half_ragged(
1561        ac_0, bc_0, a_1, b_1)
1562
1563  def testBroadcastDynamicShapeRaisesLeft(self):
1564    shape = DynamicRaggedShape.from_tensor(constant_op.constant([1, 2, 3]))
1565    with self.assertRaisesRegex(TypeError, 'shape_x must be'):
1566      dynamic_ragged_shape.broadcast_dynamic_shape(1, shape)
1567
1568  def testBroadcastDynamicShapeRaisesRight(self):
1569    shape = DynamicRaggedShape.from_tensor(constant_op.constant([1, 2, 3]))
1570    with self.assertRaisesRegex(TypeError, 'shape_y must be'):
1571      dynamic_ragged_shape.broadcast_dynamic_shape(shape, 1)
1572
1573  def testBroadcastToRaises(self):
1574    rt = constant_op.constant([1, 2, 3])
1575    with self.assertRaisesRegex(TypeError, 'shape must be'):
1576      dynamic_ragged_shape.broadcast_to(rt, 1)
1577
1578  @parameterized.parameters([
1579      dict(
1580          x=[[10], [20], [30]],  # shape=[3, 1]
1581          lengths=[3, 2],
1582          expected=[[10, 10], [20, 20], [30, 30]]),
1583      dict(
1584          x=[[10], [20], [30]],  # shape=[3, 1]
1585          lengths=[3, (3, 0, 2)],
1586          expected=ragged_factory_ops.constant_value(
1587              [[10, 10, 10], [], [30, 30]], dtype=np.int32)),
1588      dict(
1589          x=[[[1, 2, 3]], [[4, 5, 6]]],  # shape = [2, 1, 3]
1590          lengths=[2, (2, 3), 3],
1591          expected=ragged_factory_ops.constant_value(
1592              [[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]],
1593              dtype=np.int32,
1594              ragged_rank=1)),
1595      dict(
1596          x=[[[1]], [[2]]],  # shape = [2, 1, 1]
1597          lengths=[2, (2, 3), (0, 2, 1, 2, 0)],
1598          expected=ragged_factory_ops.constant_value(
1599              [[[], [1, 1]], [[2], [2, 2], []]], dtype=np.int32,
1600              ragged_rank=2)),
1601      dict(
1602          x=10,
1603          lengths=[3, (3, 0, 2)],
1604          expected=ragged_factory_ops.constant_value([[10, 10, 10], [],
1605                                                      [10, 10]])),
1606      dict(
1607          x=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
1608                                              ragged_rank=1),
1609          lengths=[2, (2, 1), 2],
1610          expected=ragged_factory_ops.constant_value(
1611              [[[1, 1], [2, 2]], [[3, 3]]], ragged_rank=1)),
1612  ])
1613  def testRaggedBroadcastTo(self, x, lengths, expected):
1614    shape = DynamicRaggedShape.from_lengths(lengths)
1615    result = dynamic_ragged_shape.broadcast_to(x, shape)
1616    self.assertEqual(
1617        getattr(result, 'num_row_partitions', 0),
1618        getattr(expected, 'num_row_partitions', 0))
1619    self.assertAllEqual(result, expected)
1620
1621    # broadcast_to just calls dynamic_ragged_shape.broadcast_to, so
1622    # this should be sufficient.
1623    result2 = ragged_array_ops.broadcast_to(x, shape)
1624    self.assertAllEqual(result2, expected)
1625
1626  @parameterized.parameters([
1627      dict(
1628          doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
1629          x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
1630                                              dtype=np.int32),
1631          y=[[10], [20], [30]],
1632          expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
1633                                                      [34, 35]])),
1634      dict(
1635          doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]',
1636          x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
1637                                              dtype=np.int32),
1638          y=10,
1639          expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
1640                                                      [14, 15]])),
1641      dict(
1642          doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
1643          x=ragged_factory_ops.constant_value([[1, 2, 3]], dtype=np.int32),
1644          y=[[10], [20], [30]],
1645          expected=ragged_factory_ops.constant_value(
1646              [[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)),
1647      dict(
1648          doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; '
1649               'bcast.shape=[2, (D1), (D2)]'),
1650          x=ragged_factory_ops.constant_value([[[1], [2], [3]], [[4]]],
1651                                              ragged_rank=1),
1652          y=ragged_factory_ops.constant_value([[10, 20, 30]]),
1653          expected=ragged_factory_ops.constant_value([[[11, 21,
1654                                                        31], [12, 22, 32],
1655                                                       [13, 23, 33]],
1656                                                      [[14, 24, 34]]])),
1657      dict(
1658          doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; '
1659               'bcast.shape=[2, (D1), 4]'),
1660          x=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
1661                                              ragged_rank=1),
1662          y=[[[1, 2, 3, 4]]],
1663          expected=ragged_factory_ops.constant_value(
1664              [[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]],
1665              ragged_rank=1)),
1666      dict(
1667          doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; '
1668               'bcast.shape=[2, (D1), (2), (D2)'),
1669          x=ragged_factory_ops.constant_value(
1670              [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1),
1671          y=ragged_factory_ops.constant_value([[10, 20], [30]]),
1672          expected=ragged_factory_ops.constant_value([[[[11, 21], [32]],
1673                                                       [[13, 23], [34]]],
1674                                                      [[[15, 25], [36]]]])),
1675  ])
1676  def testRaggedAddWithBroadcasting(self, x, y, expected, doc):
1677    expected_rrank = getattr(expected, 'num_row_partitions', 0)
1678    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
1679    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
1680    result = x + y
1681    result_rrank = getattr(result, 'num_row_partitions', 0)
1682    self.assertEqual(expected_rrank, result_rrank)
1683    if hasattr(expected, 'tolist'):
1684      expected = expected.tolist()
1685    self.assertAllEqual(result, expected)
1686
1687  @parameterized.parameters([
1688      dict(lengths_a=[3, (1, 4, 2)], new_impl=True, op_max=10),  # Actual ops: 5
1689      dict(lengths_a=[3, (1, 4, 2)], new_impl=False, op_max=300),
1690  ])
1691  def testAddSelf(self, lengths_a, new_impl, op_max, num_row_partitions_a=None):
1692    if context.executing_eagerly():
1693      return
1694    shape_a0 = DynamicRaggedShape.from_lengths(
1695        lengths_a, num_row_partitions=num_row_partitions_a)
1696    rt_a = ragged_array_ops.ragged_reshape(
1697        _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a0)
1698    rt_b = rt_a
1699    g = rt_a.flat_values.graph if ragged_tensor.is_ragged(rt_a) else rt_a.graph
1700    nodes_at_a = len(g.as_graph_def().node)
1701    if new_impl:
1702      dynamic_ragged_shape.ragged_binary_elementwise_op_impl(
1703          gen_math_ops.add_v2, rt_a, rt_b)
1704      nodes_at_b = len(g.as_graph_def().node)
1705      node_delta = nodes_at_b - nodes_at_a
1706      self.assertLessEqual(node_delta, op_max)
1707    else:
1708      if isinstance(rt_a, RaggedTensor):
1709        rt_a = rt_a.with_row_splits_dtype(dtypes.int32)
1710      rt_b = rt_a
1711      nodes_at_b = len(g.as_graph_def().node)
1712      rt_a + rt_b  # pylint: disable=pointless-statement
1713      nodes_at_d = len(g.as_graph_def().node)
1714      node_delta = nodes_at_d - nodes_at_b
1715      self.assertLessEqual(node_delta, op_max)
1716
1717  def testAndSelfBool(self):
1718    if context.executing_eagerly():
1719      return
1720    values = constant_op.constant([True, False, True, True, True])
1721    rt_a = RaggedTensor.from_row_splits(values, [0, 3, 3, 5])
1722    result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl(
1723        gen_math_ops.logical_and, rt_a, rt_a)
1724
1725    expected_values = values
1726    expected = RaggedTensor.from_row_splits(expected_values, [0, 3, 3, 5])
1727
1728    self.assertAllEqual(result, expected)
1729
1730  def testEquals(self):
1731    if context.executing_eagerly():
1732      return
1733
1734    rt_a = ragged_factory_ops.constant([[3, 1, 3], [3]])
1735    b = constant_op.constant(3)
1736    rt_expected = ragged_factory_ops.constant([[True, False, True], [True]])
1737
1738    result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl(
1739        math_ops.equal, rt_a, b)
1740    self.assertAllEqual(result, rt_expected)
1741
1742  def testEquals2(self):
1743    splits = constant_op.constant([0, 1])
1744    a = RaggedTensor.from_row_splits([[1, 2]], splits)
1745    b = RaggedTensor.from_row_splits([[3, 4, 5]], splits)
1746    self.assertIs(a == b, False)
1747
1748  def testEquals3(self):
1749    a = RaggedTensor.from_row_splits([[1, 2]], [0, 1])
1750    b = RaggedTensor.from_row_splits([[3, 4, 5]], [0, 1])
1751    self.assertIs(a == b, False)
1752
1753  @parameterized.parameters([
1754      dict(
1755          lengths_a=[3, (1, 4, 2)], lengths_b=[], new_impl=True,
1756          max_num_ops=5),  # Actual ops: 1
1757      dict(
1758          lengths_a=[3, (1, 4, 2), 3, 2],
1759          lengths_b=[3, 2],
1760          new_impl=True,
1761          max_num_ops=5),  # Actual ops: 1
1762      dict(
1763          lengths_a=[3, (1, 4, 2)], lengths_b=[], new_impl=False,
1764          max_num_ops=5),  # Actual ops: 1
1765      dict(
1766          lengths_a=[3, (1, 4, 2), 3, 2],
1767          lengths_b=[3, 2],
1768          new_impl=False,
1769          max_num_ops=5),  # Actual ops: 1
1770  ])
1771  def testAdd(self,
1772              lengths_a,
1773              lengths_b,
1774              new_impl,
1775              max_num_ops,
1776              num_row_partitions_a=None,
1777              num_row_partitions_b=None):
1778    if context.executing_eagerly():
1779      return
1780
1781    shape_a0 = DynamicRaggedShape.from_lengths(
1782        lengths_a, num_row_partitions=num_row_partitions_a)
1783    shape_b0 = DynamicRaggedShape.from_lengths(
1784        lengths_b, num_row_partitions=num_row_partitions_b)
1785    rt_a = ragged_array_ops.ragged_reshape(
1786        _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a0)
1787    rt_b = ragged_array_ops.ragged_reshape(
1788        _lowest_primes(_num_elements_of_lengths(lengths_b)), shape_b0)
1789    g = rt_a.flat_values.graph if ragged_tensor.is_ragged(rt_a) else rt_a.graph
1790
1791    nodes_at_a = len(g.as_graph_def().node)
1792    if new_impl:
1793      dynamic_ragged_shape.ragged_binary_elementwise_op_impl(
1794          gen_math_ops.add_v2,
1795          rt_a,
1796          rt_b)
1797      nodes_at_b = len(g.as_graph_def().node)
1798      num_nodes = nodes_at_b - nodes_at_a
1799      self.assertLessEqual(num_nodes, max_num_ops)
1800    else:
1801      if isinstance(rt_a, RaggedTensor):
1802        rt_a = rt_a.with_row_splits_dtype(dtypes.int32)
1803      if isinstance(rt_b, RaggedTensor):
1804        rt_b = rt_b.with_row_splits_dtype(dtypes.int32)
1805      nodes_at_b = len(g.as_graph_def().node)
1806      rt_a + rt_b  # pylint: disable=pointless-statement
1807      nodes_at_d = len(g.as_graph_def().node)
1808      num_nodes = nodes_at_d - nodes_at_b
1809
1810  @parameterized.parameters([
1811      dict(
1812          lengths_a=[3, (1, 4, 2)], lengths_b=[],
1813          shape_e=[3, None], new_impl=False),
1814      dict(
1815          lengths_a=[3, (1, 4, 2)], lengths_b=[],
1816          shape_e=[3, None], new_impl=True),
1817      dict(
1818          lengths_a=[5, (1, 4, 2, 1, 3), 3],
1819          lengths_b=[5, 1, 3],
1820          shape_e=[5, None, 3], new_impl=False),
1821      dict(
1822          lengths_a=[5, (1, 4, 2, 1, 3), 3],
1823          lengths_b=[5, 1, 3],
1824          shape_e=[5, None, 3], new_impl=True),
1825      dict(
1826          lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3],
1827          lengths_b=[3, 2, 1, 3],
1828          shape_e=[3, 2, None, 3], new_impl=False),
1829      dict(
1830          lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3],
1831          lengths_b=[3, 2, 1, 3],
1832          shape_e=[3, 2, None, 3],
1833          new_impl=True),
1834      dict(
1835          lengths_a=[3, (1, 4, 2)], lengths_b=[3, 1],
1836          shape_e=[3, None], new_impl=False),
1837      dict(
1838          lengths_a=[3, (1, 4, 2)], lengths_b=[3, 1],
1839          shape_e=[3, None], new_impl=True),
1840
1841  ])
1842  def testAddShape(self,
1843                   lengths_a,
1844                   lengths_b,
1845                   shape_e,
1846                   new_impl=False,
1847                   num_row_partitions_a=None,
1848                   num_row_partitions_b=None):
1849    if context.executing_eagerly():
1850      return
1851    shape_a = DynamicRaggedShape.from_lengths(
1852        lengths_a, num_row_partitions=num_row_partitions_a)
1853    shape_b = DynamicRaggedShape.from_lengths(
1854        lengths_b, num_row_partitions=num_row_partitions_b)
1855    rt_a = ragged_array_ops.ragged_reshape(
1856        _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a)
1857    rt_b = ragged_array_ops.ragged_reshape(
1858        _lowest_primes(_num_elements_of_lengths(lengths_b)), shape_b)
1859    if new_impl:
1860      result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl(
1861          math_ops.add, rt_a, rt_b)
1862      shape_e = tensor_shape.TensorShape(shape_e)
1863      self.assertEqual(shape_e.as_list(), result.shape.as_list())
1864    else:
1865      if isinstance(rt_a, RaggedTensor):
1866        rt_a = rt_a.with_row_splits_dtype(dtypes.int32)
1867      if isinstance(rt_b, RaggedTensor):
1868        rt_b = rt_b.with_row_splits_dtype(dtypes.int32)
1869      result = rt_a + rt_b
1870      shape_e = tensor_shape.TensorShape(shape_e)
1871      self.assertEqual(shape_e.as_list(), result.shape.as_list())
1872
1873  @parameterized.parameters([
1874      dict(
1875          lengths_a=[3, (1, 4, 2)], lengths_b=[],
1876          shape_e=[3, (1, 4, 2)]),
1877      dict(
1878          lengths_a=[5], lengths_b=[1],
1879          shape_e=[5]),
1880      dict(
1881          lengths_a=[5, (1, 4, 2, 1, 3), 3],
1882          lengths_b=[5, 1, 3],
1883          shape_e=[5, None, 3]),
1884      dict(
1885          lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3],
1886          lengths_b=[3, 2, 1, 3],
1887          shape_e=[3, 2, None, 3]),
1888      dict(lengths_a=[3, (1, 4, 2)], lengths_b=[3, 1], shape_e=[3, None]),
1889      dict(lengths_a=[5, 1, 3], lengths_b=[2, 3], shape_e=[5, 2, 3]),
1890      dict(lengths_a=[5, 1, (3, 2, 4, 1, 3)], lengths_b=[2, 1],
1891           shape_e=[5, 2, None]),
1892      dict(lengths_a=[5, 4, 1, 3], lengths_b=[2, 1], shape_e=[5, 4, 2, 3]),
1893  ])
1894  def testBroadcastDynamicShapeStatic(self,
1895                                      lengths_a,
1896                                      lengths_b,
1897                                      shape_e,
1898                                      num_row_partitions_a=None,
1899                                      num_row_partitions_b=None):
1900    if context.executing_eagerly():
1901      return
1902    shape_a = DynamicRaggedShape.from_lengths(
1903        lengths_a, num_row_partitions=num_row_partitions_a)
1904    shape_b = DynamicRaggedShape.from_lengths(
1905        lengths_b, num_row_partitions=num_row_partitions_b)
1906
1907    result = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b)
1908    result_shape = result._to_tensor_shape()
1909
1910    tensor_shape_e = [None if isinstance(x, tuple) else x for x in shape_e]
1911    self.assertEqual(shape_e, result.static_lengths())
1912    self.assertEqual(tensor_shape_e, result_shape.as_list())
1913
1914  def testBroadcastDynamicShapePartiallyKnown(self):
1915    if context.executing_eagerly():
1916      return
1917    @def_function.function(
1918        input_signature=[tensor_spec.TensorSpec(None, dtypes.int64)])
1919    def fun(x):
1920      shape_a = DynamicRaggedShape([], array_ops.stack([5, x, 3]))
1921      shape_b = DynamicRaggedShape.from_lengths([1, 3], dtype=dtypes.int64)
1922      result = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b)
1923      self.assertAllEqual([5, None, 3], result.static_lengths())
1924    fun(constant_op.constant(2, dtype=dtypes.int64))
1925
1926  def testBroadcastDynamicShapePartiallyKnownNiceToHave(self):
1927    if context.executing_eagerly():
1928      return
1929    @def_function.function(
1930        input_signature=[tensor_spec.TensorSpec(None, dtypes.int64)])
1931    def fun(x):
1932      shape_a = DynamicRaggedShape([], array_ops.stack([5, x, 3]))
1933      shape_b = DynamicRaggedShape.from_lengths([2, 3], dtype=dtypes.int64)
1934      result = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b)
1935      self.assertAllEqual([5, 2, 3], result.static_lengths())
1936    fun(constant_op.constant(2, dtype=dtypes.int64))
1937
1938  def testFromRowPartitionsStatic(self):
1939    if context.executing_eagerly():
1940      return
1941    rp = RowPartition.from_row_lengths([4, 2, 3])
1942    result = DynamicRaggedShape.from_row_partitions([rp])
1943    self.assertEqual([3, (4, 2, 3)], result.static_lengths())
1944
1945  @parameterized.parameters([
1946      dict(
1947          lengths_a=[3, (1, 4, 2)], dim=0,
1948          expected=3),
1949      dict(
1950          lengths_a=[5], dim=0,
1951          expected=5),
1952      dict(
1953          lengths_a=[5, (1, 4, 2, 1, 3), 3],
1954          dim=0,
1955          expected=5),
1956      dict(
1957          lengths_a=[5, (1, 4, 2, 1, 3), 3],
1958          dim=2,
1959          expected=3),
1960      dict(
1961          lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3],
1962          dim=1,
1963          expected=2),
1964      dict(lengths_a=[5, 1, 3], dim=0, expected=5),
1965  ])
1966  def testDimStatic(self, lengths_a, dim, expected):
1967    if context.executing_eagerly():
1968      return
1969    shape_a = DynamicRaggedShape.from_lengths(lengths_a)
1970    result = tensor_util.constant_value(shape_a[dim])
1971    self.assertEqual(result, expected)
1972
1973  @parameterized.parameters([
1974      dict(
1975          lengths_a=[5, (1, 4, 2, 1, 3), 3],
1976          shape_e=[5, (1, 4, 2, 1, 3), 3],
1977          new_num_row_partitions=2),  # Fails
1978      dict(
1979          lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3],
1980          shape_e=[3, 2, (1, 4, 2, 1, 3, 1), 3],
1981          new_num_row_partitions=3),  # Fails
1982  ])
1983  def testNumRowPartitionShapeStatic(self,
1984                                     lengths_a,
1985                                     shape_e,
1986                                     new_num_row_partitions,
1987                                     num_row_partitions_a=None):
1988    if context.executing_eagerly():
1989      return
1990    shape_a = DynamicRaggedShape.from_lengths(
1991        lengths_a, num_row_partitions=num_row_partitions_a)
1992    result = shape_a._with_num_row_partitions(new_num_row_partitions)
1993    self.assertEqual(shape_e, result.static_lengths())
1994
1995  @parameterized.parameters([
1996      dict(lengths_a=[5, (1, 4, 2, 1, 3), 3]),
1997      dict(lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3]),
1998  ])
1999  def testFromLengthsNRowsStatic(self, lengths_a):
2000    if context.executing_eagerly():
2001      return
2002    shape_a = DynamicRaggedShape.from_lengths(lengths_a)
2003    for rp in shape_a.row_partitions:
2004      actual = tensor_util.constant_value(rp.nrows())
2005      self.assertIsNotNone(actual, 'Failed on ' + str(rp))
2006
2007  @parameterized.parameters([
2008      dict(
2009          lengths_a=[5, (1, 4, 2, 1, 3), 3], inner_shape=[33],
2010          new_inner_rank=1),
2011      dict(
2012          lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3],
2013          inner_shape=[36],
2014          new_inner_rank=1),
2015      dict(
2016          lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3, 4],
2017          inner_shape=[36, 4],
2018          new_inner_rank=2),
2019  ])
2020  def testAltInnerShapeStatic(self,
2021                              lengths_a,
2022                              inner_shape,
2023                              new_inner_rank,
2024                              num_row_partitions_a=None):
2025    if context.executing_eagerly():
2026      return
2027    shape_a = DynamicRaggedShape.from_lengths(
2028        lengths_a, num_row_partitions=num_row_partitions_a)
2029    result = shape_a._alt_inner_shape(new_inner_rank)
2030    result_static = tensor_util.constant_value_as_shape(result)
2031    self.assertEqual(inner_shape, result_static.as_list())
2032
2033  @parameterized.parameters([
2034      dict(
2035          lengths=[3, (1, 4, 2)],
2036          shape_e=[3, None]),
2037      dict(
2038          lengths=[3, (1, 4, 2)],
2039          shape_e=[3, None]),
2040      dict(
2041          lengths=[5, (1, 4, 2, 1, 3), 3],
2042          shape_e=[5, None, 3]),
2043      dict(
2044          lengths=[5, (1, 4, 2, 1, 3), 3],
2045          shape_e=[5, None, 3]),
2046      dict(
2047          lengths=[3, 2, (1, 4, 2, 1, 3, 1), 3],
2048          shape_e=[3, 2, None, 3]),
2049      dict(
2050          lengths=[3, 2, (1, 4, 2, 1, 3, 1), 3],
2051          shape_e=[3, 2, None, 3]),
2052  ])
2053  def testStaticShape(self,
2054                      lengths,
2055                      shape_e,
2056                      num_row_partitions=None):
2057    # Testing the shape has enough information.
2058    # In particular, any uniform_row_length should be reproduced.
2059    if context.executing_eagerly():
2060      return
2061    shape = DynamicRaggedShape.from_lengths(
2062        lengths, num_row_partitions=num_row_partitions)
2063    rt_a = ragged_array_ops.ragged_reshape(
2064        _lowest_primes(_num_elements_of_lengths(lengths)), shape)
2065    shape_e = tensor_shape.TensorShape(shape_e)
2066    self.assertEqual(shape_e.as_list(), rt_a.shape.as_list())
2067
2068  @parameterized.parameters([
2069      dict(
2070          lengths=[5, (1, 4, 2, 1, 3), 3],
2071          shape_e=[5, (1, 4, 2, 1, 3), 3]),
2072      dict(
2073          lengths=[3, 2, (1, 4, 2, 1, 3, 1), 3],
2074          shape_e=[3, 2, (1, 4, 2, 1, 3, 1), 3]),
2075  ])
2076  def testWithNumRowPartitionsStatic(self,
2077                                     lengths,
2078                                     shape_e,
2079                                     num_row_partitions=None):
2080    # Note that this test loses the later static values.
2081    if context.executing_eagerly():
2082      return
2083    shape = DynamicRaggedShape.from_lengths(
2084        lengths, num_row_partitions=num_row_partitions)
2085    shape_b = shape._with_num_row_partitions(shape.rank - 1)
2086    self.assertEqual(shape_e, shape_b.static_lengths())
2087
2088  def testWithNumRowPartitionsStaticAlt(self):
2089    # Note that this test loses the later static values.
2090    if context.executing_eagerly():
2091      return
2092    shape = DynamicRaggedShape.from_lengths(
2093        [5, 2, 3], num_row_partitions=2)
2094    shape_b = shape._with_num_row_partitions(0)
2095    self.assertEqual([5, 2, 3], shape_b.static_lengths())
2096
2097  def testWithNumRowPartitionsDType(self):
2098    # Note that this test loses the later static values.
2099    shape = DynamicRaggedShape([], constant_op.constant([5, 2, 3],
2100                                                        dtype=dtypes.int32))
2101    self.assertEqual(shape.dtype, dtypes.int32)
2102
2103    result = shape._with_num_row_partitions(2)
2104    self.assertEqual(result.dtype, dtypes.int32)
2105
2106  def test_merge_with(self):
2107    original = DynamicRaggedShape.from_lengths([2, (3, 5), 6])
2108    result = original._merge_with(original)
2109    self.assertShapeEq(result, original)
2110
2111  def test_merge_with_spec(self):
2112    original = DynamicRaggedShape.from_lengths([2, (3, 5), 6],
2113                                               dtype=dtypes.int64)
2114    spec = DynamicRaggedShape.Spec(
2115        row_partitions=[
2116            RowPartitionSpec(nrows=2,
2117                             nvals=8,
2118                             dtype=dtypes.int64)
2119        ],
2120        static_inner_shape=tensor_shape.TensorShape([8, 6]),
2121        dtype=dtypes.int64)
2122    result = original._merge_with_spec(spec)
2123    self.assertShapeEq(result, original)
2124
2125  def test_merge_with_spec_raises(self):
2126    original = DynamicRaggedShape.from_lengths([2, (3, 5), 6],
2127                                               dtype=dtypes.int64)
2128    spec = DynamicRaggedShape.Spec(
2129        row_partitions=[
2130            RowPartitionSpec(nrows=2,
2131                             nvals=8,
2132                             dtype=dtypes.int32)
2133        ],
2134        static_inner_shape=tensor_shape.TensorShape([8, 6]),
2135        dtype=dtypes.int32)
2136    with self.assertRaisesRegex(
2137        ValueError,
2138        'RowPartition and RowPartitionSpec are not compatible'):
2139      original._merge_with_spec(spec)
2140
2141  def test_merge_with_spec_uniform(self):
2142    original = DynamicRaggedShape.from_lengths(
2143        [2, (4, 4), 6], dtype=dtypes.int64)
2144    spec = DynamicRaggedShape.Spec._from_tensor_shape(
2145        tensor_shape.TensorShape([2, 4, 6]),
2146        num_row_partitions=0,
2147        dtype=dtypes.int64)
2148    result = original._merge_with_spec(spec)
2149    original = DynamicRaggedShape.from_lengths([2, 4, 6],
2150                                               num_row_partitions=1,
2151                                               dtype=dtypes.int64)
2152    self.assertShapeEq(result, original)
2153
2154  @parameterized.parameters([
2155      dict(
2156          doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
2157          x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
2158                                              dtype=np.int32),
2159          y=[[10], [20], [30]],
2160          expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
2161                                                      [34, 35]])),
2162      dict(
2163          doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]',
2164          x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]],
2165                                              dtype=np.int32),
2166          y=10,
2167          expected=ragged_factory_ops.constant_value([[11, 12, 13], [],
2168                                                      [14, 15]])),
2169      dict(
2170          doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]',
2171          x=ragged_factory_ops.constant_value([[1, 2, 3]], dtype=np.int32),
2172          y=[[10], [20], [30]],
2173          expected=ragged_factory_ops.constant_value(
2174              [[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)),
2175      dict(
2176          doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; '
2177               'bcast.shape=[2, (D1), (D2)]'),
2178          x=ragged_factory_ops.constant_value([[[1], [2], [3]], [[4]]],
2179                                              ragged_rank=1),
2180          y=ragged_factory_ops.constant_value([[10, 20, 30]]),
2181          expected=ragged_factory_ops.constant_value([[[11, 21,
2182                                                        31], [12, 22, 32],
2183                                                       [13, 23, 33]],
2184                                                      [[14, 24, 34]]])),
2185      dict(
2186          doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; '
2187               'bcast.shape=[2, (D1), 4]'),
2188          x=ragged_factory_ops.constant_value([[[10], [20]], [[30]]],
2189                                              ragged_rank=1),
2190          y=[[[1, 2, 3, 4]]],
2191          expected=ragged_factory_ops.constant_value(
2192              [[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]],
2193              ragged_rank=1)),
2194      dict(
2195          doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; '
2196               'bcast.shape=[2, (D1), (2), (D2)'),
2197          x=ragged_factory_ops.constant_value(
2198              [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1),
2199          y=ragged_factory_ops.constant_value([[10, 20], [30]]),
2200          expected=ragged_factory_ops.constant_value([[[[11, 21], [32]],
2201                                                       [[13, 23], [34]]],
2202                                                      [[[15, 25], [36]]]])),
2203  ])
2204  def testRaggedDispatchImplWithBroadcasting(self, x, y, expected, doc):
2205    expected_rrank = getattr(expected, 'num_row_partitions', 0)
2206    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
2207    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
2208    result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl(
2209        gen_math_ops.add_v2, x, y)
2210    result_rrank = getattr(result, 'num_row_partitions', 0)
2211    self.assertEqual(expected_rrank, result_rrank)
2212    if hasattr(expected, 'tolist'):
2213      expected = expected.tolist()
2214    self.assertAllEqual(result, expected)
2215
2216  def testDimensions(self):
2217    a = DynamicRaggedShape._from_inner_shape([1, 2, 3])
2218    self.assertAllEqual(1, a._dimension(0))
2219
2220  def testGetItemIsInstanceTensor(self):
2221    a = dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape([1, 2, 3])
2222    self.assertIsInstance(a[0], ops.Tensor)
2223
2224  @parameterized.parameters([
2225      dict(
2226          lengths=[2, 2],
2227          num_row_partitions=1,
2228          expected=[2, 2]),
2229      dict(lengths=[2, 2], num_row_partitions=0, expected=[2, 2]),
2230      dict(
2231          lengths=[2, (1, 2), 2], num_row_partitions=1, expected=[2, (1, 2), 2])
2232  ])
2233  def testStaticLengths(self,
2234                        lengths,
2235                        num_row_partitions,
2236                        expected,
2237                        expected_eager=None):
2238    a = DynamicRaggedShape.from_lengths(lengths)._with_num_row_partitions(
2239        num_row_partitions)
2240    actual = a.static_lengths()
2241    if context.executing_eagerly() and expected_eager is not None:
2242      self.assertAllEqual(expected_eager, actual)
2243    else:
2244      self.assertAllEqual(expected, actual)
2245
2246  def testStaticLengthsUnknown(self):
2247
2248    @def_function.function(
2249        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2250    def foo(row_lengths):
2251      a = DynamicRaggedShape([RowPartition.from_row_lengths(row_lengths)], [6])
2252      actual = a.static_lengths()
2253      self.assertAllEqual([None, None], actual)
2254
2255    foo([3, 3])
2256
2257  def testStaticLengthsRankUnknown(self):
2258    # Note that the rank of the shape is unknown, so we can only provide a
2259    # prefix of the lengths.
2260    @def_function.function(
2261        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2262    def foo(inner_shape):
2263      a = DynamicRaggedShape([RowPartition.from_row_lengths([3, 3])],
2264                             inner_shape)
2265      actual = a.static_lengths()
2266      self.assertAllEqual([2, (3, 3), ...], actual)
2267
2268    foo([6, 3])
2269
2270  def testReprRankKnown(self):
2271    a = DynamicRaggedShape.from_lengths([2, (1, 2), 3])
2272    actual = str(a)
2273    self.assertEqual(
2274        '<DynamicRaggedShape lengths=[2, (1, 2), 3] num_row_partitions=1>',
2275        actual)
2276
2277  def assertDimsEqual(self, x: tensor_shape.TensorShape,
2278                      y: tensor_shape.TensorShape):
2279    if x.rank is None:
2280      self.assertIsNone(
2281          y.rank,
2282          'x has an unknown rank, but y does not: x={}, y={}'.format(x, y))
2283      return
2284    self.assertIsNotNone(
2285        y.rank,
2286        'y has an unknown rank, but x does not: x={}, y={}'.format(x, y))
2287    self.assertAllEqual(x.as_list(), y.as_list())
2288
2289  def testToTensorShapeRankKnown(self):
2290    a = DynamicRaggedShape.from_lengths([2, (1, 2), 3])
2291    actual = a._to_tensor_shape()
2292    self.assertDimsEqual(tensor_shape.TensorShape([2, None, 3]), actual)
2293
2294  def testReprRankUnknown(self):
2295
2296    @def_function.function(
2297        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2298    def foo(inner_shape):
2299      a = DynamicRaggedShape([RowPartition.from_row_lengths([3, 3])],
2300                             inner_shape)
2301      actual = str(a)
2302      self.assertEqual(
2303          '<DynamicRaggedShape lengths=[2, (3, 3), ...] num_row_partitions=1>',
2304          actual)
2305
2306    foo([6, 3])
2307
2308  def testToTensorShapeRankUnknown(self):
2309    @def_function.function(
2310        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2311    def foo(inner_shape):
2312      a = DynamicRaggedShape([RowPartition.from_row_lengths([3, 3])],
2313                             inner_shape)
2314      actual = a._to_tensor_shape()
2315      self.assertDimsEqual(
2316          tensor_shape.TensorShape(None), actual)
2317
2318    foo([6, 3])
2319
2320  def testBroadcastDynamicShapeExtendedRankOne(self):
2321    a = DynamicRaggedShape._from_inner_shape([1])
2322    b = DynamicRaggedShape._from_inner_shape([3])
2323    (c, ac, bc) = dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b)
2324    expected_c = DynamicRaggedShape._from_inner_shape([3])
2325    self.assertShapeEq(c, expected_c)
2326    ac_result = ac.broadcast(constant_op.constant([4]))
2327    self.assertAllEqual(ac_result, [4, 4, 4])
2328    bc_result = bc.broadcast(constant_op.constant([4, 7, 1]))
2329    self.assertAllEqual(bc_result, [4, 7, 1])
2330
2331  def testBroadcastDynamicShapeExtendedRankOneRev(self):
2332    a = DynamicRaggedShape._from_inner_shape([3])
2333    b = DynamicRaggedShape._from_inner_shape([1])
2334    (c, ac, bc) = dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b)
2335    expected_c = DynamicRaggedShape._from_inner_shape([3])
2336    self.assertShapeEq(c, expected_c)
2337    bc_result = bc.broadcast(constant_op.constant([4]))
2338    self.assertAllEqual(bc_result, [4, 4, 4])
2339    ac_result = ac.broadcast(constant_op.constant([4, 7, 1]))
2340    self.assertAllEqual(ac_result, [4, 7, 1])
2341
2342  def testBroadcastDynamicShapeExtendedRankOneIdentity(self):
2343    a = DynamicRaggedShape._from_inner_shape([3])
2344    b = DynamicRaggedShape._from_inner_shape([3])
2345    (c, ac, bc) = dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b)
2346    expected_c = DynamicRaggedShape._from_inner_shape([3])
2347    self.assertShapeEq(c, expected_c)
2348    bc_result = bc.broadcast(constant_op.constant([4, 7, 1]))
2349    self.assertAllEqual(bc_result, [4, 7, 1])
2350    ac_result = ac.broadcast(constant_op.constant([4, 7, 1]))
2351    self.assertAllEqual(ac_result, [4, 7, 1])
2352
2353  def testFromGatherLayerIndexRaises(self):
2354    bad_gather_index = constant_op.constant([0.0, 0.5, 1.0])
2355    with self.assertRaisesRegex(ValueError, 'gather_index must be'):
2356      _LayerBroadcaster.from_gather_index(bad_gather_index)
2357
2358  ### Tests mostly for code coverage ###########################################
2359
2360  def testFindPreferredDtypeIntNone(self):
2361    actual = dynamic_ragged_shape._find_dtype(3, None)
2362    self.assertIsNone(actual)
2363
2364  @parameterized.parameters([
2365      dict(
2366          source_shape=lambda: DynamicRaggedShape._from_inner_shape([3]),
2367          target_shape=lambda: DynamicRaggedShape._from_inner_shape([3]),
2368          layer_broadcasters=lambda: [int],
2369          dtype=None,
2370          error_type=TypeError,
2371          error_regex=r'Not a LayerBroadcaster'),
2372      dict(
2373          source_shape=lambda: DynamicRaggedShape._from_inner_shape([3]),
2374          target_shape=lambda: DynamicRaggedShape._from_inner_shape([3]),
2375          layer_broadcasters=lambda: _LayerBroadcaster.from_gather_index(
2376              [0, 1, 2]),
2377          dtype=None,
2378          error_type=TypeError,
2379          error_regex=r'layer'),
2380      dict(
2381          source_shape=lambda: DynamicRaggedShape._from_inner_shape([3]),
2382          target_shape=lambda: None,
2383          layer_broadcasters=lambda:
2384          [_LayerBroadcaster.from_gather_index([0, 1, 2])],
2385          dtype=None,
2386          error_type=TypeError,
2387          error_regex='target_shape is not a DynamicRaggedShape'),
2388      dict(
2389          source_shape=lambda: None,
2390          target_shape=lambda: DynamicRaggedShape._from_inner_shape([3]),
2391          layer_broadcasters=lambda:
2392          [_LayerBroadcaster.from_gather_index([0, 1, 2])],
2393          dtype=None,
2394          error_type=TypeError,
2395          error_regex='source_shape is not a DynamicRaggedShape')
2396  ])
2397  def testBroadcasterInitRaises(self, source_shape, target_shape,
2398                                layer_broadcasters, dtype, error_type,
2399                                error_regex):
2400    source_shape = source_shape()
2401    target_shape = target_shape()
2402    layer_broadcasters = layer_broadcasters()
2403    with self.assertRaisesRegex(error_type, error_regex):
2404      dynamic_ragged_shape._Broadcaster(
2405          source_shape, target_shape, layer_broadcasters, dtype=dtype)
2406
2407  def testBroadcasterRepr(self):
2408    source_shape = DynamicRaggedShape(
2409        [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))],
2410        constant_op.constant([3]))
2411    target_shape = DynamicRaggedShape(
2412        [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))],
2413        constant_op.constant([3]))
2414    layer_broadcasters = [
2415        _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2])),
2416        _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2]))
2417    ]
2418    bc = dynamic_ragged_shape._Broadcaster(source_shape, target_shape,
2419                                           layer_broadcasters)
2420    actual = str(bc)
2421    self.assertRegex(actual, '.src_shape..DynamicRaggedShape')
2422
2423  def testBroadcasterWithDtype(self):
2424    source_shape = DynamicRaggedShape(
2425        [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))],
2426        constant_op.constant([3]))
2427    target_shape = DynamicRaggedShape(
2428        [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))],
2429        constant_op.constant([3]))
2430    layer_broadcasters = [
2431        _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2])),
2432        _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2]))
2433    ]
2434    bc = dynamic_ragged_shape._Broadcaster(
2435        source_shape, target_shape, layer_broadcasters, dtype=dtypes.int32)
2436
2437    bc2 = bc.with_dtype(dtypes.int64)
2438    self.assertEqual(bc2.dtype, dtypes.int64)
2439
2440  # TODO(martinz): This doesn't work for ragged_tensor_shape.
2441  # Uncomment when we switch over the implementation.
2442  #    dict(dtype=dtypes.int32)
2443  @parameterized.parameters([
2444      dict(dtype=dtypes.int64)
2445  ])
2446  def testBroadcasterWithDenseDType(self, dtype):
2447    a = constant_op.constant([[4]])
2448    b = RaggedTensor.from_row_splits([[2], [3], [4], [5]], [0, 3, 4])
2449    b = b.with_row_splits_dtype(dtype)
2450    c = a + b
2451    self.assertEqual(c.row_splits.dtype, dtype)
2452    d = b + a
2453    self.assertEqual(d.row_splits.dtype, dtype)
2454
2455  @parameterized.parameters([
2456      dict(dtype_left=dtypes.int64,
2457           dtype_right=dtypes.int32),
2458      dict(dtype_left=dtypes.int32,
2459           dtype_right=dtypes.int64)])
2460  def testBroadcastWithDifferentDenseShapeDTypes(self, dtype_left,
2461                                                 dtype_right):
2462    s_left = DynamicRaggedShape._from_inner_shape(
2463        constant_op.constant([4, 1], dtype_left))
2464    s_right = DynamicRaggedShape._from_inner_shape(
2465        constant_op.constant([1, 4], dtype_right))
2466    s_result = dynamic_ragged_shape.broadcast_dynamic_shape(s_left, s_right)
2467    self.assertEqual(s_result.dtype, dtypes.int64)
2468
2469  def testBroadcastFlatValuesToDenseExpand(self):
2470    source = RaggedTensor.from_uniform_row_length([0, 1, 2, 3], 2)
2471    target_shape = DynamicRaggedShape._from_inner_shape([1, 2, 2])
2472    broadcaster = dynamic_ragged_shape._get_broadcaster(
2473        DynamicRaggedShape.from_tensor(source), target_shape)
2474    flat_values = broadcaster.broadcast_flat_values(source)
2475    self.assertAllEqual(flat_values, [[[0, 1], [2, 3]]])
2476
2477  # TODO(edloper): Confirm that this is the expected behavior.
2478  def testBroadcastFlatValuesToDenseExpandInnerDimensionsFalse(self):
2479    source = RaggedTensor.from_uniform_row_length([0, 1, 2, 3], 2)
2480    target_shape = DynamicRaggedShape._from_inner_shape([1, 2, 2])
2481    broadcaster = dynamic_ragged_shape._get_broadcaster(
2482        DynamicRaggedShape.from_tensor(source), target_shape)
2483    flat_values = broadcaster.broadcast_flat_values(
2484        source, inner_dimensions=False)
2485    self.assertAllEqual(flat_values, [[0, 1], [2, 3]])
2486
2487  def testGetLayerBroadcastersFromRPSRaisesTypeError(self):
2488    with self.assertRaisesRegex(TypeError, 'Not a _LayerBroadcaster'):
2489      dynamic_ragged_shape._get_layer_broadcasters_from_rps(int, [], [])
2490
2491  def testGetBroadcasterRankDrop(self):
2492    with self.assertRaisesRegex(ValueError, 'Cannot broadcast'):
2493      a = DynamicRaggedShape._from_inner_shape([3, 4, 5])
2494      b = DynamicRaggedShape._from_inner_shape([4, 5])
2495      dynamic_ragged_shape._get_broadcaster(a, b)
2496
2497  @parameterized.parameters([
2498      dict(
2499          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2500          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2501          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2502          b_1=lambda: None,
2503          error_type=TypeError,
2504          error_regex='b_1 should be a RowPartition'),
2505      dict(
2506          ac_0=lambda: None,
2507          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2508          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2509          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2510          error_type=TypeError,
2511          error_regex='ac_0 should be a _LayerBroadcaster'),
2512      dict(
2513          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2514          bc_0=lambda: None,
2515          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2516          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2517          error_type=TypeError,
2518          error_regex='bc_0 should be a _LayerBroadcaster'),
2519      dict(
2520          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2521          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2522          a_1=lambda: None,
2523          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2524          error_type=TypeError,
2525          error_regex='a_1 should be a RowPartition')
2526  ])
2527  def testBroadcastDynamicShapeNextLayerHalfRaggedRaises(
2528      self, ac_0, bc_0, a_1, b_1, error_type, error_regex):
2529    ac_0 = ac_0()
2530    bc_0 = bc_0()
2531    a_1 = a_1()
2532    b_1 = b_1()
2533    with self.assertRaisesRegex(error_type, error_regex):
2534      dynamic_ragged_shape._broadcast_dynamic_shape_next_layer_half_ragged(
2535          ac_0, bc_0, a_1, b_1)
2536
2537  @parameterized.parameters([
2538      dict(
2539          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2540          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2541          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2542          b_1=lambda: None,
2543          error_type=TypeError,
2544          error_regex='b_1 should be a RowPartition'),
2545      dict(
2546          ac_0=lambda: None,
2547          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2548          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2549          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2550          error_type=TypeError,
2551          error_regex='ac_0 should be a _LayerBroadcaster'),
2552      dict(
2553          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2554          bc_0=lambda: None,
2555          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2556          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2557          error_type=TypeError,
2558          error_regex='bc_0 should be a _LayerBroadcaster'),
2559      dict(
2560          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2561          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2562          a_1=lambda: None,
2563          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2564          error_type=TypeError,
2565          error_regex='a_1 should be a RowPartition')
2566  ])
2567  def testBroadcastDynamicShapeNextLayerBothUniformRaises(
2568      self, ac_0, bc_0, a_1, b_1, error_type, error_regex):
2569    ac_0 = ac_0()
2570    bc_0 = bc_0()
2571    a_1 = a_1()
2572    b_1 = b_1()
2573    with self.assertRaisesRegex(error_type, error_regex):
2574      dynamic_ragged_shape._broadcast_dynamic_shape_next_layer_both_uniform(
2575          ac_0, bc_0, a_1, b_1)
2576
2577  @parameterized.parameters([
2578      dict(
2579          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2580          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2581          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2582          b_1=lambda: None,
2583          error_type=TypeError,
2584          error_regex='b_1 should be a RowPartition'),
2585      dict(
2586          ac_0=lambda: None,
2587          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2588          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2589          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2590          error_type=TypeError,
2591          error_regex='ac_0 should be a _LayerBroadcaster'),
2592      dict(
2593          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2594          bc_0=lambda: None,
2595          a_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2596          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2597          error_type=TypeError,
2598          error_regex='bc_0 should be a _LayerBroadcaster'),
2599      dict(
2600          ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2601          bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]),
2602          a_1=lambda: None,
2603          b_1=lambda: RowPartition.from_row_splits([0, 1, 2]),
2604          error_type=TypeError,
2605          error_regex='a_1 should be a RowPartition')
2606  ])
2607  def testBroadcastDynamicShapeNextLayerRaises(self, ac_0, bc_0, a_1, b_1,
2608                                               error_type, error_regex):
2609    ac_0 = ac_0()
2610    bc_0 = bc_0()
2611    a_1 = a_1()
2612    b_1 = b_1()
2613    with self.assertRaisesRegex(error_type, error_regex):
2614      dynamic_ragged_shape._broadcast_dynamic_shape_next_layer(
2615          ac_0, bc_0, a_1, b_1)
2616
2617  @parameterized.parameters([
2618      dict(
2619          left_dtype=dtypes.int64,
2620          right_dtype=dtypes.int64,
2621          expected_dtype=dtypes.int64),
2622      dict(
2623          left_dtype=dtypes.int32,
2624          right_dtype=dtypes.int32,
2625          expected_dtype=dtypes.int32)
2626  ])
2627  def testAddingRowSplits(self, left_dtype, right_dtype, expected_dtype):
2628    x = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(left_dtype)
2629    y = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(right_dtype)
2630    z = math_ops.add(x, y)
2631    self.assertEqual(z.row_splits.dtype, expected_dtype)
2632
2633  @parameterized.parameters([
2634      dict(left_dtype=dtypes.int32, right_dtype=dtypes.int64),
2635      dict(left_dtype=dtypes.int64, right_dtype=dtypes.int32),
2636  ])
2637  def testAddingRowSplitsError(self, left_dtype, right_dtype):
2638    x = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(left_dtype)
2639    y = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(right_dtype)
2640    with self.assertRaisesRegex(
2641        ValueError, 'Input RaggedTensors have mismatched row_splits dtypes'):
2642      math_ops.add(x, y)
2643
2644  def testAddRowPartitionsInvalidV1(self):
2645    if not context.executing_eagerly():
2646      return
2647
2648    with self.assertRaisesRegex(
2649        (errors_impl.InvalidArgumentError, ValueError),
2650        'Last row partition does not match flat_values.'):
2651      rt = ragged_factory_ops.constant([[3], [4, 5], [6]])
2652      rt_shape = DynamicRaggedShape.from_tensor(rt)
2653      new_flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
2654      rt_shape._add_row_partitions(new_flat_values, validate=True)
2655
2656  # Example #1:
2657  # [2, (3, 1), 5], num_row_partitions = 1, outer_axis = 0, inner_axis = 1.
2658  # Result: [4, 5], num_row_partitions = 0.
2659  # Example #2:
2660  # [2, (2, 1), (7, 8, 9), 5], num_row_partitions = 2, outer_axis = 1,
2661  #     inner_axis = 2.
2662  # Result: [2, (15, 9), 5], num_row_partitions = 1.
2663  # Example #3:
2664  # [2, (2, 1), (7, 8, 9), 5], num_row_partitions = 2, outer_axis = 0,
2665  #     inner_axis = 1.
  # Result: [3, (7, 8, 9), 5], num_row_partitions = 1.
2667  # Here, we are merging the tail of the row_partitions,
2668  # but the inner_shape is unchanged.
2669
2670  @parameterized.parameters([
2671      # NOOP
2672      dict(
2673          lengths=[2, (3, 1), 5],
2674          num_row_partitions=1,
2675          outer_axis=1,
2676          inner_axis=1,
2677          expected_lengths=[2, (3, 1), 5],
2678          expected_num_row_partitions=1),
2679      # Where num_row_partitions == 0
2680      dict(
2681          lengths=[2, 7, 5, 4],
2682          num_row_partitions=0,
2683          outer_axis=1,
2684          inner_axis=2,
2685          expected_lengths=[2, 35, 4],
2686          expected_num_row_partitions=0),
2687      # Where inner_axis <= self.num_row_partitions
2688      dict(
2689          lengths=[2, (3, 1), 5],
2690          num_row_partitions=1,
2691          outer_axis=0,
2692          inner_axis=1,
2693          expected_lengths=[4, 5],
2694          expected_num_row_partitions=0),
2695      dict(
2696          lengths=[2, (2, 1), (7, 8, 9), 5],
2697          num_row_partitions=2,
2698          outer_axis=1,
2699          inner_axis=2,
2700          expected_lengths=[2, (15, 9), 5],
2701          expected_num_row_partitions=1),
2702      # outer_axis > num_row_partitions (only inner_shape changed)
2703      dict(
2704          lengths=[2, (1, 2), 5, 3],
2705          num_row_partitions=1,
2706          outer_axis=2,
2707          inner_axis=3,
2708          expected_lengths=[2, (1, 2), 15],
2709          expected_num_row_partitions=1),
2710      # outer_axis <= num_row_partitions
2711      # inner_axis > num_row_partitions (everything changes)
2712      # (If outer_axis == 0, all row_partitions are truncated).
2713      dict(
2714          lengths=[2, (2, 1), (7, 8, 9), 2, 5],
2715          num_row_partitions=2,
2716          outer_axis=0,
2717          inner_axis=3,
2718          expected_lengths=[48, 5],
2719          expected_num_row_partitions=0),
2720      dict(
2721          lengths=[2, (2, 1), (7, 8, 9), 2, 5],
2722          num_row_partitions=2,
2723          outer_axis=1,
2724          inner_axis=3,
2725          expected_lengths=[2, (30, 18), 5],
2726          expected_num_row_partitions=1),
2727  ])
2728  def test_merge_dims(self, lengths, num_row_partitions, outer_axis, inner_axis,
2729                      expected_lengths, expected_num_row_partitions):
2730    original = DynamicRaggedShape.from_lengths(
2731        lengths, num_row_partitions=num_row_partitions)
2732    actual = original._merge_dims(outer_axis, inner_axis)
2733    expected = DynamicRaggedShape.from_lengths(expected_lengths,
2734                                               expected_num_row_partitions)
2735    self.assertShapeEq(actual, expected)
2736
2737  def test_merge_dims_special(self):
2738    rt = ragged_factory_ops.constant([[[1, 2], [3]], [[4]]])
2739    original = DynamicRaggedShape.from_tensor(rt)
2740    actual = original._merge_dims(0, 1)
2741    self.assertAllEqual(actual[0], 3)
2742
2743  def testGetItemRankNoneTruncate(self):
2744    @def_function.function(
2745        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2746    def foo(x):
2747      rts = DynamicRaggedShape.from_tensor(x)
2748      actual = rts[:1]
2749      self.assertShapeEq(rts, actual)
2750
2751    foo([1, 2, 3])
2752
2753  def test_dataset_only_dense(self):
2754    ragged = DynamicRaggedShape.from_lengths([4, 5, 2, 3])
2755    dataset_ops.DatasetV2.from_tensors(ragged)
2756
2757  def test_dataset_only_ragged(self):
2758    ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5), 2, 3])
2759    dataset_ops.DatasetV2.from_tensors(ragged)
2760
2761  def test_ragged_dataset(self):
2762    rt = RaggedTensor.from_row_splits(array_ops.zeros([5, 2, 3]), [0, 3, 5])
2763    dataset_ops.DatasetV2.from_tensors(rt)
2764
2765  def test_ones_shape(self):
2766    ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5)])
2767    ones = dynamic_ragged_shape.ones(ragged, dtype=bool)
2768    sh2 = DynamicRaggedShape.from_tensor(ones)
2769    self.assertAllEqual(sh2.static_lengths(), [4, (3, 0, 4, 5)])
2770
2771  def test_dataset_only_simple_ragged(self):
2772    ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5)])
2773    dataset_ops.DatasetV2.from_tensors(ragged)
2774
2775  # ValueError: _to_batched_tensor_list doesn't support ragged_rank=0 yet
2776  def test_unbatch_batch_dense(self):
2777    ragged = DynamicRaggedShape.from_lengths([4, 5, 2, 3])
2778    ds = dataset_ops.DatasetV2.from_tensors(ragged)
2779    dsu = ds.unbatch()
2780    if context.executing_eagerly():
2781      values = list(dsu)
2782      self.assertAllEqual(values[0].static_lengths(), [5, 2, 3])
2783      self.assertAllEqual(values[2].static_lengths(), [5, 2, 3])
2784
2785    dsb = dsu.batch(2)
2786    if context.executing_eagerly():
2787      valuesb = list(dsb)
2788      self.assertAllEqual(valuesb[0].static_lengths(), [2, 5, 2, 3])
2789      self.assertAllEqual(valuesb[1].static_lengths(), [2, 5, 2, 3])
2790
2791  def test_unbatch_batch_values_shape_0(self):
2792    batched = DynamicRaggedShape.from_lengths([2])
2793    batch_size = 2
2794    ds = dataset_ops.Dataset.from_tensors(batched)
2795    ds2 = ds.unbatch()
2796    if context.executing_eagerly():
2797      v = list(ds2.batch(batch_size))
2798      self.assertAllEqual(v[0], batched)
2799
2800  def test_unbatch_batch_values_shape_1(self):
2801    batched = DynamicRaggedShape.from_lengths([2, 3])
2802    rebatched = DynamicRaggedShape.from_lengths([2, 3], num_row_partitions=1)
2803
2804    batch_size = 2
2805    ds = dataset_ops.Dataset.from_tensors(batched)
2806    ds2 = ds.unbatch()
2807    if context.executing_eagerly():
2808      v = list(ds2.batch(batch_size))
2809      self.assertAllEqual(v[0], rebatched)
2810
2811  def test_unbatch_dense_matrix(self):
2812    ragged = DynamicRaggedShape.from_lengths([2, 3])
2813    ds = dataset_ops.DatasetV2.from_tensors(ragged)
2814    dsu = ds.unbatch()
2815    if context.executing_eagerly():
2816      values = list(dsu)
2817      self.assertAllEqual(values[0].static_lengths(), [3])
2818      self.assertAllEqual(values[1].static_lengths(), [3])
2819
2820  def test_unbatch_dense_vector(self):
2821    ragged = DynamicRaggedShape.from_lengths([3])
2822    ds = dataset_ops.DatasetV2.from_tensors(ragged)
2823    dsu = ds.unbatch()
2824    if context.executing_eagerly():
2825      values = list(dsu)
2826      self.assertAllEqual(values[0].static_lengths(), [])
2827      self.assertAllEqual(values[1].static_lengths(), [])
2828
2829  def test_unbatch_ragged(self):
2830    ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5), 2, 3])
2831    ds = dataset_ops.DatasetV2.from_tensors(ragged)
2832    dsu = ds.unbatch()
2833    if context.executing_eagerly():
2834      dsu.__iter__()
2835
2836  def test_unbatch_batch_ragged(self):
2837    ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5), 2, 3])
2838    ds = dataset_ops.DatasetV2.from_tensors(ragged)
2839    dsu = ds.unbatch()
2840    if context.executing_eagerly():
2841      values = list(dsu)
2842      self.assertAllEqual(values[0].static_lengths(), [3, 2, 3])
2843      self.assertAllEqual(values[2].static_lengths(), [4, 2, 3])
2844
2845    dsb = dsu.batch(2)
2846    if context.executing_eagerly():
2847      valuesb = list(dsb)
2848      self.assertAllEqual(valuesb[0].static_lengths(), [2, (3, 0), 2, 3])
2849      self.assertAllEqual(valuesb[1].static_lengths(), [2, (4, 5), 2, 3])
2850
2851
2852class DynamicRaggedShapeErrorTest(parameterized.TestCase):
2853
2854  @parameterized.parameters([
2855      # Broadcast [1, 2, (1, 2)] to [1, 2, (2, 1)] (FAIL)
2856      dict(
2857          origin_lengths=[2, 1, (1, 2)],
2858          origin_values=[2, 3, 5],
2859          expected_lengths=[1, 2, (2, 1)]),
2860      # Broadcast [2, 1, (1, 1)] -> [2, 1, (5, 5)] (UNSUPPORTED)
2861      dict(
2862          origin_lengths=[2, 1, (1, 1)],
2863          origin_values=[2, 3],
2864          expected_lengths=[2, 1, (5, 5)]),
2865      # Broadcast [1, 2, (1, 2)] to [2, 2, (2, 1, 1, 2)] (FAIL)
2866      dict(
2867          origin_lengths=[1, 2, (1, 2)],
2868          origin_values=[2, 3, 5],
2869          expected_lengths=[2, 2, (2, 1, 1, 2)]),
2870      # Broadcast w.shape = [2,1,(1,3)] to w'.shape = [2,1,(3,3)] (UNSUPPORTED)
2871      dict(
2872          origin_lengths=[2, 1, (1, 3)],
2873          origin_values=[2, 3, 5, 7],  # [[[2]], [[3, 5, 7]]]
2874          expected_lengths=[2, 1, (3, 3)]),
2875  ])
2876  def testBroadcastRaggedError(self, origin_lengths, origin_values,
2877                               expected_lengths):
2878    # I pulled this out of the tensorflow test case, so that I could have
2879    # more control.
2880    # However this error is being generated, it confuses assertRaises,
2881    # but it exists.
2882    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2883                                r'Cannot broadcast'):
2884      # with self.assertRaisesRegex(errors.InvalidArgumentError,
2885      #                             r"Cannot broadcast"):
2886      sess = session.Session()
2887      with sess.as_default():
2888        origin = _to_ragged_tensor_from_lengths(origin_values, origin_lengths)
2889        expected_shape = DynamicRaggedShape.from_lengths(expected_lengths)
2890
2891        rt = dynamic_ragged_shape.broadcast_to(origin, expected_shape)
2892        sess.run([rt])
2893
2894  @parameterized.parameters([
2895      # nvals and nrows don't match (3 != 4) dynamically
2896      dict(
2897          row_partitions=lambda: [  # pylint: disable=g-long-lambda
2898              RowPartition.from_uniform_row_length(1, 3, nrows=3),
2899              RowPartition.from_uniform_row_length(1, 4, nrows=4)
2900          ],
2901          inner_shape=lambda: [4],
2902          validate=True,
2903          error_regex='RowPartitions in DynamicRaggedShape do not'),
2904      # nvals and inner_shape[0] don't match (3 != 4) dynamically
2905      dict(
2906          row_partitions=lambda: [  # pylint: disable=g-long-lambda
2907              RowPartition.from_uniform_row_length(1, 3, nrows=3),
2908          ],
2909          inner_shape=lambda: [4],
2910          validate=True,
2911          error_regex='Last row partition does not match inner_shape.'),
2912  ])
2913  def testConstructorRaisesDynamic(self,
2914                                   row_partitions,
2915                                   inner_shape,
2916                                   error_regex,
2917                                   validate=False,
2918                                   dtype=None):
2919    with self.assertRaisesRegex((errors_impl.InvalidArgumentError, ValueError),
2920                                error_regex):
2921      sess = session.Session()
2922      with sess.as_default():
2923        row_partitions = row_partitions()
2924        inner_shape = inner_shape()
2925        rts = DynamicRaggedShape(
2926            row_partitions, inner_shape, dtype=dtype, validate=validate)
2927        sess.run([rts.inner_shape])
2928
2929  def testRankNone(self):
2930
2931    @def_function.function(
2932        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2933    def foo(x):
2934      rts = DynamicRaggedShape._from_inner_shape(x)
2935      self.assertIsNone(rts.rank)
2936
2937    foo([3, 7, 5])
2938
2939  def testNumSlicesInDimensionRankNone(self):
2940    with self.assertRaisesRegex(ValueError, 'rank is undefined'):
2941
2942      @def_function.function(
2943          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2944      def foo(x):
2945        rts = DynamicRaggedShape._from_inner_shape(x)
2946        rts._num_slices_in_dimension(-1)
2947
2948      foo([3, 7, 5])
2949
2950  def testGetItemRankNone(self):
2951    with self.assertRaisesRegex(ValueError, 'Rank must be known to'):
2952
2953      @def_function.function(
2954          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2955      def foo(x):
2956        rts = DynamicRaggedShape._from_inner_shape(x)
2957        rts[-1]  # pylint: disable=pointless-statement
2958
2959      foo([3, 7, 5])
2960
2961  def testWithDenseRankRankNone(self):
2962    with self.assertRaisesRegex(ValueError, 'Rank must be known to'):
2963
2964      @def_function.function(
2965          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2966      def foo(x):
2967        rts = DynamicRaggedShape._from_inner_shape(x)
2968        rts._with_inner_rank(1)
2969
2970      foo([3, 7, 5])
2971
2972  def testWithRaggedRankRankNone(self):
2973    with self.assertRaisesRegex(ValueError, 'Rank must be known to'):
2974
2975      @def_function.function(
2976          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2977      def foo(x):
2978        rts = DynamicRaggedShape._from_inner_shape(x)
2979        rts._with_num_row_partitions(1)
2980
2981      foo([3, 7, 5])
2982
2983  def testAsRowPartitionsRankNone(self):
2984    # Error is readable, but does not match strings correctly.
2985    with self.assertRaisesRegex(ValueError, ''):
2986
2987      @def_function.function(
2988          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
2989      def foo(x):
2990        rts = DynamicRaggedShape._from_inner_shape(x)
2991        rts._as_row_partitions()
2992
2993      foo([3, 7, 5])
2994
2995  def testBroadcastDynamicShapeExtendedRankNone(self):
2996    with self.assertRaisesRegex(ValueError,
2997                                'Unable to broadcast: unknown rank'):
2998
2999      @def_function.function(
3000          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
3001      def foo(x):
3002        a = DynamicRaggedShape._from_inner_shape(x)
3003        b = DynamicRaggedShape._from_inner_shape([1, 1, 1])
3004        dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b)
3005
3006      foo([3, 7, 5])
3007
3008  def testBroadcastDynamicShapeUnmatchedTypes6432(self):
3009    shape_int64 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)],
3010                                                  dtype=dtypes.int64)
3011    shape_int32 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)],
3012                                                  dtype=dtypes.int32)
3013    with self.assertRaisesRegex(ValueError, "Dtypes don't match"):
3014      dynamic_ragged_shape.broadcast_dynamic_shape(shape_int64, shape_int32)
3015
3016  def testBroadcastDynamicShapeUnmatchedTypes3264(self):
3017    shape_int64 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)],
3018                                                  dtype=dtypes.int64)
3019    shape_int32 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)],
3020                                                  dtype=dtypes.int32)
3021    with self.assertRaisesRegex(ValueError, "Dtypes don't match"):
3022      dynamic_ragged_shape.broadcast_dynamic_shape(shape_int32, shape_int64)
3023
3024  def testGetIdentityBroadcasterRankNone(self):
3025    with self.assertRaisesRegex(ValueError, 'Shape must have a'):
3026
3027      @def_function.function(
3028          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
3029      def foo(x):
3030        rts = DynamicRaggedShape._from_inner_shape(x)
3031        dynamic_ragged_shape._get_identity_broadcaster(rts)
3032
3033      foo([3, 7, 5])
3034
3035  def testLayerBroadcasterRepr(self):
3036    index = constant_op.constant([0, 1, 2], name='testLayerBroadcasterRepr')
3037    lb = _LayerBroadcaster.from_gather_index(index)
3038    actual = str(lb)
3039    self.assertRegex(actual, '.*Tensor.*, shape=.3... dtype=int32.')
3040
3041  def testGetBroadcasterRankNoneLeft(self):
3042    with self.assertRaisesRegex(ValueError, 'Rank of source and target must'):
3043
3044      @def_function.function(
3045          input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
3046      def foo(x):
3047        rts_a = DynamicRaggedShape._from_inner_shape(x)
3048        rts_b = DynamicRaggedShape._from_inner_shape(x)
3049        dynamic_ragged_shape._get_broadcaster(rts_a, rts_b)
3050
3051      foo([3, 7, 5])
3052
3053  def testFromTensorDType(self):
3054    x = ragged_factory_ops.constant([[1, 2]])
3055    self.assertEqual(x.row_splits.dtype, dtypes.int64)
3056    shape_x = DynamicRaggedShape.from_tensor(x)
3057    self.assertEqual(shape_x.dtype, dtypes.int64)
3058
3059  def testAddingRowSplits(self):
3060    x = ragged_factory_ops.constant([[1, 2]])
3061    self.assertEqual(x.row_splits.dtype, dtypes.int64)
3062
3063    y = math_ops.add(x, x)
3064    self.assertEqual(y.row_splits.dtype, dtypes.int64)
3065
3066  def testHashingWithMask(self):
3067    inp_data = ragged_factory_ops.constant(
3068        [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']],
3069        dtype=dtypes.string)
3070    mask = math_ops.equal(inp_data, '')
3071    values = string_ops.string_to_hash_bucket_strong(
3072        inp_data, 3, name='hash', key=[0xDECAFCAFFE, 0xDECAFCAFFE])
3073    values = math_ops.add(values, array_ops.ones_like(values))
3074    local_zeros = array_ops.zeros_like(values)
3075    values = array_ops.where(mask, local_zeros, values)
3076
3077  def testAddRowPartitionsInvalid(self):
3078    with self.assertRaisesRegex(
3079        (errors_impl.InvalidArgumentError, ValueError),
3080        'Last row partition does not match flat_values.'):
3081      sess = session.Session()
3082      with sess.as_default():
3083        rt = ragged_factory_ops.constant([[3], [4, 5], [6]])
3084        rt_shape = DynamicRaggedShape.from_tensor(rt)
3085        new_flat_values = constant_op.constant(['a', 'b', 'c'])
3086        rt2 = rt_shape._add_row_partitions(new_flat_values, validate=True)
3087        sess.run([rt2])
3088
3089
3090class DynamicRaggedShapeSpecTest(parameterized.TestCase):
3091
3092  def assertRowPartitionSpecEqual(self,
3093                                  a: RowPartitionSpec,
3094                                  b: RowPartitionSpec,
3095                                  msg='') -> None:
3096    self.assertEqual(a.nrows, b.nrows, msg)
3097    self.assertEqual(a.nvals, b.nvals, msg)
3098    self.assertEqual(a.uniform_row_length, b.uniform_row_length, msg)
3099    self.assertEqual(a.dtype, b.dtype, msg)
3100
3101  def assertTensorShapeEqual(self, a: tensor_shape.TensorShape,
3102                             b: tensor_shape.TensorShape) -> None:
3103    self.assertEqual(a, b)
3104
3105  def assertTensorSpecEqual(self,
3106                            a: tensor_spec.TensorSpec,
3107                            b: tensor_spec.TensorSpec) -> None:
3108    self.assertTensorShapeEqual(a.shape, b.shape)
3109    self.assertEqual(a.dtype, b.dtype)
3110
3111  def assertDynamicRaggedShapeSpecEqual(self,
3112                                        a: DynamicRaggedShape.Spec,
3113                                        b: DynamicRaggedShape.Spec) -> None:
3114    self.assertTensorShapeEqual(a._static_inner_shape, b._static_inner_shape)
3115    self.assertTensorSpecEqual(a._inner_shape, b._inner_shape)
3116    for i, (a, b) in enumerate(zip(a._row_partitions, b._row_partitions)):
3117      self.assertRowPartitionSpecEqual(a, b, 'Error in partition ' + str(i))
3118
3119  @parameterized.parameters([
3120      # Unknown dimension
3121      dict(
3122          shape=tensor_shape.TensorShape(None),
3123          num_row_partitions=1,
3124          dtype=dtypes.int32,
3125          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3126              row_partitions=[
3127                  RowPartitionSpec(
3128                      nrows=None,
3129                      nvals=None,
3130                      uniform_row_length=None,
3131                      dtype=dtypes.int32),
3132                  RowPartitionSpec(
3133                      nrows=None,
3134                      nvals=None,
3135                      uniform_row_length=None,
3136                      dtype=dtypes.int32)
3137              ],
3138              static_inner_shape=tensor_shape.TensorShape(None),
3139              dtype=dtypes.int32)),
3140      # Unknown dimension, dense
3141      dict(
3142          shape=tensor_shape.TensorShape(None),
3143          num_row_partitions=0,
3144          dtype=dtypes.int32,
3145          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3146              row_partitions=[],
3147              static_inner_shape=tensor_shape.TensorShape(None),
3148              dtype=dtypes.int32)),
3149      # Scalar
3150      dict(
3151          shape=tensor_shape.TensorShape([]),
3152          num_row_partitions=0,
3153          dtype=dtypes.int32,
3154          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3155              row_partitions=[],
3156              static_inner_shape=tensor_shape.TensorShape([]),
3157              dtype=dtypes.int32)),
3158      # Vector
3159      dict(
3160          shape=tensor_shape.TensorShape([7]),
3161          num_row_partitions=0,
3162          dtype=dtypes.int32,
3163          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3164              row_partitions=[],
3165              static_inner_shape=tensor_shape.TensorShape([7]),
3166              dtype=dtypes.int32)),
3167      # Generic
3168      dict(
3169          shape=tensor_shape.TensorShape([5, 3, None, 4, 2, 5]),
3170          num_row_partitions=3,
3171          dtype=dtypes.int32,
3172          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3173              row_partitions=[
3174                  RowPartitionSpec(
3175                      nrows=5,
3176                      nvals=15,
3177                      uniform_row_length=3,
3178                      dtype=dtypes.int32),
3179                  RowPartitionSpec(
3180                      nrows=15,
3181                      nvals=None,
3182                      uniform_row_length=None,
3183                      dtype=dtypes.int32),
3184                  RowPartitionSpec(
3185                      nrows=None,
3186                      nvals=None,
3187                      uniform_row_length=4,
3188                      dtype=dtypes.int32)
3189              ],
3190              static_inner_shape=tensor_shape.TensorShape([None, 2, 5]),
3191              dtype=dtypes.int32)),
3192      # Generic, Dense
3193      dict(
3194          shape=tensor_shape.TensorShape([5, 3, None, 4, 2, 5]),
3195          num_row_partitions=0,
3196          dtype=dtypes.int32,
3197          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3198              row_partitions=[],
3199              static_inner_shape=tensor_shape.TensorShape(
3200                  [5, 3, None, 4, 2, 5]),
3201              dtype=dtypes.int32)),
3202  ])
3203  def test_from_tensor_shape(self, shape, num_row_partitions, dtype, expected):
3204    spec = DynamicRaggedShape.Spec._from_tensor_shape(shape, num_row_partitions,
3205                                                      dtype)
3206    self.assertDynamicRaggedShapeSpecEqual(spec, expected)
3207
3208  @parameterized.parameters([
3209      # Ridiculous DType.
3210      dict(
3211          shape=tensor_shape.TensorShape(None),
3212          num_row_partitions=1,
3213          dtype=dtypes.float32,
3214          error_type=ValueError,
3215          error_regex='dtype must be tf.int32 or tf.int64'),
3216      # num_row_partitions positive for scalar.
3217      dict(
3218          shape=tensor_shape.TensorShape([]),
3219          num_row_partitions=1,
3220          dtype=dtypes.int32,
3221          error_type=ValueError,
3222          error_regex='num_row_partitions should be zero ' +
3223          'if shape is a scalar or vector.'),
3224      dict(
3225          shape=tensor_shape.TensorShape([1, 2, 3]),
3226          num_row_partitions=3,
3227          dtype=dtypes.int32,
3228          error_type=ValueError,
3229          error_regex='num_row_partitions must be less than rank')
3230  ])
3231  def test_from_tensor_shape_raises(self, shape, num_row_partitions, dtype,
3232                                    error_type, error_regex):
3233    with self.assertRaisesRegex(error_type, error_regex):
3234      DynamicRaggedShape.Spec._from_tensor_shape(shape, num_row_partitions,
3235                                                 dtype)
3236
3237  def test_from_tensor_shape_raises_dtype(self):
3238    with self.assertRaisesRegex(ValueError,
3239                                'dtype must be tf.int32 or tf.int64'):
3240      DynamicRaggedShape.Spec._from_tensor_shape(
3241          [], tensor_shape.TensorShape([1, 2, 3]), dtypes.float32)
3242
3243  def test_from_row_partition_inner_shape_and_dtype_raises_dtype(self):
3244    with self.assertRaisesRegex(
3245        ValueError, r'dtype of .* is .*int64.*: expected .*int32.*'):
3246      DynamicRaggedShape.Spec(
3247          row_partitions=[
3248              RowPartitionSpec(
3249                  nrows=None,
3250                  nvals=None,
3251                  uniform_row_length=None,
3252                  dtype=dtypes.int32),
3253              RowPartitionSpec(
3254                  nrows=None,
3255                  nvals=None,
3256                  uniform_row_length=None,
3257                  dtype=dtypes.int64)
3258          ],
3259          static_inner_shape=tensor_shape.TensorShape(None),
3260          dtype=dtypes.int32)
3261
3262  def test_ranks(self):
3263    spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3264        shape=tensor_shape.TensorShape([5, None, 7, 4, 2, 5]),
3265        num_row_partitions=2,
3266        dtype=dtypes.int32)
3267
3268    self.assertEqual(spec.inner_rank, 4)
3269    self.assertEqual(spec.num_row_partitions, 2)
3270    self.assertEqual(spec.rank, 6)
3271
3272  def test_dimension_simple(self):
3273    spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3274        shape=tensor_shape.TensorShape([5, None, 7, 4, 2, 5]),
3275        num_row_partitions=2,
3276        dtype=dtypes.int32)
3277
3278    self.assertEqual(spec._dimension(0), 5)
3279    self.assertIsNone(spec._dimension(1))
3280    self.assertEqual(spec._dimension(2), 7)
3281    self.assertEqual(spec._dimension(3), 4)
3282    self.assertEqual(spec._dimension(4), 2)
3283    self.assertEqual(spec._dimension(5), 5)
3284
3285  @parameterized.parameters([
3286      dict(
3287          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3288              None, 0, dtypes.int32),
3289          dimension=0),
3290      dict(
3291          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3292              None, 0, dtypes.int32),
3293          dimension=1),
3294  ])
3295  def test_dimension_none(self, spec, dimension):
3296    actual = spec._dimension(dimension)
3297    self.assertIsNone(actual)
3298
3299  @parameterized.parameters([
3300      # Scalar.
3301      dict(
3302          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3303              [], 0, dtypes.int32),
3304          dimension=0,
3305          error_type=ValueError,
3306          error_regex='Index out of range: 0.'),
3307      # Scalar.
3308      dict(
3309          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3310              [], 0, dtypes.int32),
3311          dimension=1,
3312          error_type=ValueError,
3313          error_regex='Index out of range: 1.'),
3314  ])
3315  def test_dimension_raises(self, spec, dimension, error_type, error_regex):
3316    with self.assertRaisesRegex(error_type, error_regex):
3317      spec._dimension(dimension)
3318
3319  def test_num_slices_in_dimension_ragged(self):
3320    spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3321        shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]),
3322        num_row_partitions=2,
3323        dtype=dtypes.int32)
3324
3325    self.assertEqual(spec._num_slices_in_dimension(0), 5)
3326    self.assertEqual(spec._num_slices_in_dimension(1), 5 * 3)
3327    self.assertEqual(spec._num_slices_in_dimension(2), 5 * 3 * 7)
3328    self.assertEqual(spec._num_slices_in_dimension(3), 5 * 3 * 7 * 4)
3329    self.assertIsNone(spec._num_slices_in_dimension(4))
3330    self.assertIsNone(spec._num_slices_in_dimension(5))
3331    self.assertIsNone(spec._num_slices_in_dimension(-2))
3332
3333  def test_num_slices_in_dimension_ragged_alt(self):
3334    spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3335        shape=tensor_shape.TensorShape([5, 3, None, 2]),
3336        num_row_partitions=3,
3337        dtype=dtypes.int32)
3338
3339    self.assertEqual(spec._num_slices_in_dimension(0), 5)
3340    self.assertEqual(spec._num_slices_in_dimension(1), 5 * 3)
3341    self.assertIsNone(spec._num_slices_in_dimension(2))
3342    self.assertIsNone(spec._num_slices_in_dimension(3))
3343
3344  def test_num_slices_in_dimension_dense_known(self):
3345    spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3346        [5, 3, 4], 0, dtypes.int32)
3347
3348    self.assertEqual(spec._num_slices_in_dimension(0), 5)
3349    self.assertEqual(spec._num_slices_in_dimension(1), 15)
3350    self.assertEqual(spec._num_slices_in_dimension(2), 60)
3351
3352  @parameterized.parameters([
3353      dict(
3354          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3355              None, 0, dtypes.int32),
3356          dimension='CRAZY',
3357          error_type=TypeError,
3358          error_regex='axis must be an integer'),
3359      dict(
3360          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3361              None, 0, dtypes.int32),
3362          dimension=-1,
3363          error_type=ValueError,
3364          error_regex='axis=-1 may only be negative' +
3365          ' if rank is statically known.')
3366  ])
3367  def test_num_slices_in_dimension_raises(self, spec, dimension, error_type,
3368                                          error_regex):
3369    with self.assertRaisesRegex(error_type, error_regex):
3370      spec._num_slices_in_dimension(dimension)
3371
3372  def test_with_dtype(self):
3373    spec = DynamicRaggedShape.Spec._from_tensor_shape(
3374        shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]),
3375        num_row_partitions=2,
3376        dtype=dtypes.int32)
3377    actual = spec.with_dtype(dtypes.int64)
3378    self.assertEqual(actual.dtype, dtypes.int64)
3379    self.assertEqual(actual._row_partitions[0].dtype, dtypes.int64)
3380    self.assertEqual(actual._row_partitions[1].dtype, dtypes.int64)
3381
3382  @parameterized.parameters([
3383      dict(
3384          original=DynamicRaggedShape.Spec._from_tensor_shape(
3385              shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]),
3386              num_row_partitions=2,
3387              dtype=dtypes.int32),
3388          num_row_partitions=3,
3389          expected=DynamicRaggedShape.Spec._from_tensor_shape(
3390              shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]),
3391              num_row_partitions=3,
3392              dtype=dtypes.int32)),
3393      dict(
3394          original=DynamicRaggedShape.Spec._from_tensor_shape(
3395              shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]),
3396              num_row_partitions=2,
3397              dtype=dtypes.int32),
3398          num_row_partitions=1,
3399          expected=DynamicRaggedShape.Spec._from_tensor_shape(
3400              shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]),
3401              num_row_partitions=1,
3402              dtype=dtypes.int32)),
3403  ])
3404  def test_with_num_row_partitions(self, original, num_row_partitions,
3405                                   expected):
3406    actual = original._with_num_row_partitions(num_row_partitions)
3407    self.assertDynamicRaggedShapeSpecEqual(actual, expected)
3408
3409  @parameterized.parameters([
3410      dict(
3411          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3412              None, 0, dtypes.int32),
3413          num_row_partitions=2,
3414          error_type=ValueError,
3415          error_regex='Changing num_row_partitions with unknown rank'),
3416      dict(
3417          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3418              [1, 2, 3, 4], 0, dtypes.int32),
3419          num_row_partitions=4,
3420          error_type=ValueError,
3421          error_regex='Number of row partitions too large'),
3422      dict(
3423          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3424              [1, 2, 3, 4], 0, dtypes.int32),
3425          num_row_partitions=-3,
3426          error_type=ValueError,
3427          error_regex='Number of row partitions negative'),
3428  ])
3429  def test_with_num_row_partitions_raises(self, spec, num_row_partitions,
3430                                          error_type, error_regex):
3431    with self.assertRaisesRegex(error_type, error_regex):
3432      spec._with_num_row_partitions(num_row_partitions)
3433
3434  def test_truncate(self):
3435    spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
3436        shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]),
3437        num_row_partitions=2,
3438        dtype=dtypes.int32)
3439
3440    for new_rank in range(7):
3441      truncation = spec._truncate(new_rank)
3442      self.assertEqual(truncation.rank, new_rank)
3443      for i in range(new_rank):
3444        self.assertEqual(
3445            truncation._dimension(i), spec._dimension(i),
3446            'Mismatch on new_rank ' + str(new_rank) + ' on dimension ' + str(i))
3447
3448  def test_truncate_unknown(self):
3449    spec = DynamicRaggedShape.Spec(
3450        row_partitions=[
3451            RowPartitionSpec(
3452                nrows=3, nvals=7, uniform_row_length=None, dtype=dtypes.int32),
3453            RowPartitionSpec(
3454                nrows=7,
3455                nvals=None,
3456                uniform_row_length=None,
3457                dtype=dtypes.int32)
3458        ],
3459        static_inner_shape=tensor_shape.TensorShape(None),
3460        dtype=dtypes.int32)
3461    expected = DynamicRaggedShape.Spec(
3462        row_partitions=[
3463            RowPartitionSpec(
3464                nrows=3, nvals=7, uniform_row_length=None, dtype=dtypes.int32),
3465            RowPartitionSpec(
3466                nrows=7,
3467                nvals=None,
3468                uniform_row_length=None,
3469                dtype=dtypes.int32)
3470        ],
3471        static_inner_shape=tensor_shape.TensorShape([None, None]),
3472        dtype=dtypes.int32)
3473    actual = spec._truncate(4)
3474    self.assertDynamicRaggedShapeSpecEqual(actual, expected)
3475
3476  @parameterized.parameters([
3477      # Standard scalar
3478      dict(
3479          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3480              row_partitions=[],
3481              static_inner_shape=tensor_shape.TensorShape([]),
3482              dtype=dtypes.int32),
3483          expected=0),
3484      dict(
3485          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3486              row_partitions=[
3487                  RowPartitionSpec(
3488                      nrows=None,
3489                      nvals=None,
3490                      uniform_row_length=None,
3491                      dtype=dtypes.int64)
3492              ],
3493              static_inner_shape=tensor_shape.TensorShape([None]),
3494              dtype=dtypes.int64),
3495          expected=1),
3496      # Not knowing the shape of the inner shape is weird.
3497      dict(
3498          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3499              row_partitions=[
3500                  RowPartitionSpec(
3501                      nrows=None,
3502                      nvals=None,
3503                      uniform_row_length=None,
3504                      dtype=dtypes.int64)
3505              ],
3506              static_inner_shape=tensor_shape.TensorShape(None),
3507              dtype=dtypes.int64),
3508          expected=None),
3509  ])
3510  def test_inner_rank(self, spec, expected):
3511    actual = spec.inner_rank
3512    self.assertEqual(expected, actual)
3513
3514  @parameterized.parameters([
3515      # Standard scalar
3516      dict(
3517          other_spec=tensor_spec.TensorSpec([], dtypes.float32),
3518          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3519              row_partitions=[],
3520              static_inner_shape=tensor_shape.TensorShape([]),
3521              dtype=dtypes.int64)),
3522      dict(
3523          other_spec=ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32),
3524          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3525              row_partitions=[
3526                  RowPartitionSpec(nrows=None,
3527                                   nvals=None,
3528                                   uniform_row_length=None,
3529                                   dtype=dtypes.int64)
3530              ],
3531              static_inner_shape=tensor_shape.TensorShape([None]),
3532              dtype=dtypes.int64)),
3533      dict(
3534          other_spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3535              row_partitions=[
3536                  RowPartitionSpec(nrows=None,
3537                                   nvals=None,
3538                                   uniform_row_length=None,
3539                                   dtype=dtypes.int64)
3540              ],
3541              static_inner_shape=tensor_shape.TensorShape([None]),
3542              dtype=dtypes.int64),
3543          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3544              row_partitions=[
3545                  RowPartitionSpec(nrows=None,
3546                                   nvals=None,
3547                                   uniform_row_length=None,
3548                                   dtype=dtypes.int64)
3549              ],
3550              static_inner_shape=tensor_shape.TensorShape([None]),
3551              dtype=dtypes.int64)),
3552  ])
3553  def test_from_spec(self, other_spec, expected):
3554    actual = DynamicRaggedShape.Spec._from_spec(other_spec)
3555    self.assertDynamicRaggedShapeSpecEqual(expected, actual)
3556
3557  @parameterized.parameters([
3558      dict(
3559          row_partitions=[
3560              RowPartitionSpec(
3561                  nrows=None,
3562                  nvals=None,
3563                  uniform_row_length=None,
3564                  dtype=dtypes.int64)
3565          ],
3566          static_inner_shape=tensor_shape.TensorShape([None]),
3567          inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3568      dict(
3569          row_partitions=[
3570              RowPartitionSpec(
3571                  nrows=None,
3572                  nvals=None,
3573                  uniform_row_length=None,
3574                  dtype=dtypes.int64)
3575          ],
3576          static_inner_shape=tensor_shape.TensorShape([None, 3]),
3577          inner_shape=tensor_spec.TensorSpec([2], dtypes.int64)),
3578      dict(
3579          row_partitions=[
3580              RowPartitionSpec(
3581                  nrows=6,
3582                  nvals=None,
3583                  uniform_row_length=None,
3584                  dtype=dtypes.int64)
3585          ],
3586          static_inner_shape=tensor_shape.TensorShape([None]),
3587          inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3588      dict(
3589          row_partitions=[
3590              RowPartitionSpec(
3591                  nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64)
3592          ],
3593          static_inner_shape=tensor_shape.TensorShape([60]),
3594          inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3595      dict(
3596          row_partitions=[
3597              RowPartitionSpec(
3598                  nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64),
3599              RowPartitionSpec(
3600                  nrows=60,
3601                  nvals=120,
3602                  uniform_row_length=None,
3603                  dtype=dtypes.int64)
3604          ],
3605          static_inner_shape=tensor_shape.TensorShape([120]),
3606          inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3607      dict(
3608          row_partitions=[
3609              RowPartitionSpec(
3610                  nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64)
3611          ],
3612          static_inner_shape=tensor_shape.TensorShape(None),
3613          inner_shape=tensor_spec.TensorSpec([None], dtypes.int64))
3614  ])
3615  def test_constructor_idempotent(self, row_partitions, static_inner_shape,
3616                                  inner_shape):
3617    # The constructor detects if there is any additional information that
3618    # can be inferred from what is given.
3619    original = dynamic_ragged_shape.DynamicRaggedShape.Spec(
3620        row_partitions, static_inner_shape, inner_shape.dtype)
3621    self.assertTensorShapeEqual(original._static_inner_shape,
3622                                static_inner_shape)
3623    self.assertTensorSpecEqual(original._inner_shape, inner_shape)
3624    for i, (a, b) in enumerate(zip(original._row_partitions, row_partitions)):
3625      self.assertRowPartitionSpecEqual(a, b, 'Error in partition ' + str(i))
3626
3627  @parameterized.parameters([
3628      dict(
3629          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3630              row_partitions=[
3631                  RowPartitionSpec(
3632                      nrows=3,
3633                      nvals=None,
3634                      uniform_row_length=4,
3635                      dtype=dtypes.int64)
3636              ],
3637              static_inner_shape=tensor_shape.TensorShape([None]),
3638              dtype=dtypes.int64),
3639          expected_row_partitions=[
3640              RowPartitionSpec(
3641                  nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64)
3642          ],
3643          expected_static_inner_shape=tensor_shape.TensorShape([12]),
3644          expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3645      dict(
3646          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3647              row_partitions=[
3648                  RowPartitionSpec(
3649                      nrows=None,
3650                      nvals=None,
3651                      uniform_row_length=3,
3652                      dtype=dtypes.int64)
3653              ],
3654              static_inner_shape=tensor_shape.TensorShape([30]),
3655              dtype=dtypes.int64),
3656          expected_row_partitions=[
3657              RowPartitionSpec(
3658                  nrows=10, nvals=30, uniform_row_length=3, dtype=dtypes.int64)
3659          ],
3660          expected_static_inner_shape=tensor_shape.TensorShape([30]),
3661          expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3662      dict(
3663          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3664              row_partitions=[
3665                  RowPartitionSpec(
3666                      nrows=6,
3667                      nvals=None,
3668                      uniform_row_length=10,
3669                      dtype=dtypes.int64)
3670              ],
3671              static_inner_shape=tensor_shape.TensorShape([None]),
3672              dtype=dtypes.int64),
3673          expected_row_partitions=[
3674              RowPartitionSpec(
3675                  nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64)
3676          ],
3677          expected_static_inner_shape=tensor_shape.TensorShape([60]),
3678          expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3679      dict(
3680          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3681              row_partitions=[
3682                  RowPartitionSpec(
3683                      nrows=6,
3684                      nvals=None,
3685                      uniform_row_length=None,
3686                      dtype=dtypes.int64),
3687                  RowPartitionSpec(
3688                      nrows=60,
3689                      nvals=None,
3690                      uniform_row_length=None,
3691                      dtype=dtypes.int64)
3692              ],
3693              static_inner_shape=tensor_shape.TensorShape([120]),
3694              dtype=dtypes.int64),
3695          expected_row_partitions=[
3696              RowPartitionSpec(
3697                  nrows=6,
3698                  nvals=60,
3699                  uniform_row_length=None,
3700                  dtype=dtypes.int64),
3701              RowPartitionSpec(
3702                  nrows=60,
3703                  nvals=120,
3704                  uniform_row_length=None,
3705                  dtype=dtypes.int64)
3706          ],
3707          expected_static_inner_shape=tensor_shape.TensorShape([120]),
3708          expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)),
3709  ])
3710  def test_constructor_improvements(self, original, expected_row_partitions,
3711                                    expected_static_inner_shape,
3712                                    expected_inner_shape):
3713    # Note that self_merge is only idempotent if no data is partially present.
3714    self.assertTensorShapeEqual(original._static_inner_shape,
3715                                expected_static_inner_shape)
3716    self.assertTensorSpecEqual(original._inner_shape, expected_inner_shape)
3717    for i, (a, b) in enumerate(
3718        zip(original._row_partitions, expected_row_partitions)):
3719      self.assertRowPartitionSpecEqual(a, b, 'Error in partition ' + str(i))
3720
3721  @parameterized.parameters([
3722      dict(
3723          row_partitions=[
3724              RowPartitionSpec(
3725                  nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64)
3726          ],
3727          static_inner_shape=tensor_shape.TensorShape([]),
3728          dtype=dtypes.int64,
3729          error_type=ValueError,
3730          msg='If row_partitions are provided, must have inner_rank > 0'),
3731      dict(
3732          row_partitions=RowPartitionSpec(
3733              nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64),
3734          static_inner_shape=tensor_shape.TensorShape([]),
3735          dtype=dtypes.int64,
3736          error_type=TypeError,
3737          msg='row_partitions should be an Iterable'),
3738      dict(
3739          row_partitions=[1, 2, 3],
3740          static_inner_shape=tensor_shape.TensorShape([12]),
3741          dtype=dtypes.int64,
3742          error_type=TypeError,
3743          msg='row_partitions should be an Iterable of RowPartitionSpecs'),
3744      dict(
3745          row_partitions=[
3746              RowPartitionSpec(
3747                  nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64)
3748          ],
3749          static_inner_shape=3,
3750          dtype=dtypes.int64,
3751          error_type=ValueError,
3752          msg='Dimensions 12 and 3'),
3753      dict(
3754          row_partitions=[
3755              RowPartitionSpec(
3756                  nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64)
3757          ],
3758          static_inner_shape=tensor_shape.TensorShape([2]),
3759          dtype=456,
3760          error_type=TypeError,
3761          msg='Cannot convert'),
3762      dict(
3763          row_partitions=[
3764              RowPartitionSpec(
3765                  nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64)
3766          ],
3767          static_inner_shape=tensor_shape.TensorShape([12]),
3768          dtype=dtypes.int32,
3769          error_type=ValueError,
3770          msg='dtype of RowPartitionSpec'),
3771      dict(
3772          row_partitions=[
3773              RowPartitionSpec(
3774                  nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64)
3775          ],
3776          static_inner_shape=tensor_shape.TensorShape([11]),
3777          dtype=dtypes.int64,
3778          error_type=ValueError,
3779          msg='Dimensions 12 and 11 are not compatible'),
3780      dict(
3781          row_partitions=[
3782              RowPartitionSpec(nvals=3, dtype=dtypes.int64),
3783              RowPartitionSpec(uniform_row_length=4, dtype=dtypes.int64),
3784              RowPartitionSpec(nrows=17, dtype=dtypes.int64),
3785          ],
3786          static_inner_shape=tensor_shape.TensorShape([20]),
3787          dtype=dtypes.int64,
3788          error_type=ValueError,
3789          msg='Dimensions 17 and 12 are not compatible'),
3790  ])
3791  def test_constructor_raises(self, row_partitions, static_inner_shape,
3792                              dtype, error_type, msg):
3793    # Note that self_merge is only idempotent if no data is partially present.
3794    with self.assertRaisesRegex(error_type, msg):
3795      dynamic_ragged_shape.DynamicRaggedShape.Spec(
3796          row_partitions=row_partitions,
3797          static_inner_shape=static_inner_shape,
3798          dtype=dtype)
3799
3800  @parameterized.parameters([
3801      # Unknown rank
3802      dict(
3803          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3804              row_partitions=[],
3805              static_inner_shape=tensor_shape.TensorShape(None),
3806              dtype=dtypes.int64),
3807          expected=tensor_shape.TensorShape(None)),
3808      # Scalar
3809      dict(
3810          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3811              row_partitions=[],
3812              static_inner_shape=tensor_shape.TensorShape([]),
3813              dtype=dtypes.int64),
3814          expected=tensor_shape.TensorShape([])),
3815      # Vector
3816      dict(
3817          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3818              row_partitions=[],
3819              static_inner_shape=tensor_shape.TensorShape([3]),
3820              dtype=dtypes.int64),
3821          expected=tensor_shape.TensorShape([3])),
3822      # Dense
3823      dict(
3824          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3825              row_partitions=[],
3826              static_inner_shape=tensor_shape.TensorShape([3, 2, None]),
3827              dtype=dtypes.int64),
3828          expected=tensor_shape.TensorShape([3, 2, None])),
3829      # Ragged
3830      dict(
3831          original=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3832              row_partitions=[
3833                  RowPartitionSpec(nrows=6,
3834                                   nvals=None,
3835                                   uniform_row_length=10,
3836                                   dtype=dtypes.int64),
3837                  RowPartitionSpec(nrows=60,
3838                                   nvals=None,
3839                                   uniform_row_length=None,
3840                                   dtype=dtypes.int64)
3841              ],
3842              static_inner_shape=tensor_shape.TensorShape([120]),
3843              dtype=dtypes.int64),
3844          expected=tensor_shape.TensorShape([6, 10, None])),
3845
3846  ])
3847  def test_to_tensor_shape(self, original, expected):
3848    # Note that self_merge is only idempotent if no data is partially present.
3849    actual = original._to_tensor_shape()
3850    self.assertEqual(actual, expected)
3851
3852  @parameterized.parameters([
3853      dict(
3854          a=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3855              row_partitions=[],
3856              static_inner_shape=tensor_shape.TensorShape([]),
3857              dtype=dtypes.int32),
3858          b=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3859              row_partitions=[],
3860              static_inner_shape=tensor_shape.TensorShape([]),
3861              dtype=dtypes.int32),
3862          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3863              row_partitions=[],
3864              static_inner_shape=tensor_shape.TensorShape([]),
3865              dtype=dtypes.int32)),
3866      dict(
3867          a=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3868              row_partitions=[],
3869              static_inner_shape=tensor_shape.TensorShape([3, None]),
3870              dtype=dtypes.int32),
3871          b=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3872              row_partitions=[],
3873              static_inner_shape=tensor_shape.TensorShape([None, 4]),
3874              dtype=dtypes.int32),
3875          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3876              row_partitions=[],
3877              static_inner_shape=tensor_shape.TensorShape([3, 4]),
3878              dtype=dtypes.int32)),
3879      dict(
3880          a=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3881              row_partitions=[
3882                  RowPartitionSpec(
3883                      nrows=6,
3884                      nvals=None,
3885                      uniform_row_length=None,
3886                      dtype=dtypes.int64)
3887              ],
3888              static_inner_shape=tensor_shape.TensorShape([None]),
3889              dtype=dtypes.int64),
3890          b=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3891              row_partitions=[
3892                  RowPartitionSpec(
3893                      nrows=6,
3894                      nvals=None,
3895                      uniform_row_length=10,
3896                      dtype=dtypes.int64)
3897              ],
3898              static_inner_shape=tensor_shape.TensorShape([None]),
3899              dtype=dtypes.int64),
3900          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3901              row_partitions=[
3902                  RowPartitionSpec(
3903                      nrows=6,
3904                      nvals=60,
3905                      uniform_row_length=10,
3906                      dtype=dtypes.int64)
3907              ],
3908              static_inner_shape=tensor_shape.TensorShape([60]),
3909              dtype=dtypes.int64)),
3910      dict(
3911          a=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3912              row_partitions=[
3913                  RowPartitionSpec(
3914                      nrows=6,
3915                      nvals=None,
3916                      uniform_row_length=None,
3917                      dtype=dtypes.int64)
3918              ],
3919              static_inner_shape=tensor_shape.TensorShape([None]),
3920              dtype=dtypes.int64),
3921          b=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3922              row_partitions=[],
3923              static_inner_shape=tensor_shape.TensorShape([None, 10]),
3924              dtype=dtypes.int64),
3925          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3926              row_partitions=[
3927                  RowPartitionSpec(
3928                      nrows=6,
3929                      nvals=60,
3930                      uniform_row_length=10,
3931                      dtype=dtypes.int64)
3932              ],
3933              static_inner_shape=tensor_shape.TensorShape([60]),
3934              dtype=dtypes.int64))
3935  ])
3936  def test_merge_with(self,
3937                      a: DynamicRaggedShape.Spec,
3938                      b: DynamicRaggedShape.Spec,
3939                      expected: DynamicRaggedShape.Spec):
3940    actual = a._merge_with(b)
3941    actual_rev = b._merge_with(a)
3942
3943    self.assertDynamicRaggedShapeSpecEqual(actual, expected)
3944    self.assertDynamicRaggedShapeSpecEqual(actual_rev, expected)
3945
3946  @parameterized.parameters([
3947      dict(
3948          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3949              row_partitions=[
3950                  RowPartitionSpec(
3951                      nrows=6,
3952                      nvals=3,
3953                      uniform_row_length=None,
3954                      dtype=dtypes.int64)
3955              ],
3956              static_inner_shape=tensor_shape.TensorShape([3]),
3957              dtype=dtypes.int64),
3958          batch_size=3,
3959          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3960              row_partitions=[
3961                  RowPartitionSpec(
3962                      nrows=3,
3963                      nvals=18,
3964                      uniform_row_length=6,
3965                      dtype=dtypes.int64),
3966                  RowPartitionSpec(
3967                      nrows=18,
3968                      nvals=9,
3969                      uniform_row_length=None,
3970                      dtype=dtypes.int64)
3971              ],
3972              static_inner_shape=tensor_shape.TensorShape([9]),
3973              dtype=dtypes.int64)),
3974      dict(
3975          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3976              row_partitions=[
3977                  RowPartitionSpec(
3978                      nrows=None,
3979                      nvals=3,
3980                      uniform_row_length=None,
3981                      dtype=dtypes.int64)
3982              ],
3983              static_inner_shape=tensor_shape.TensorShape([3]),
3984              dtype=dtypes.int64),
3985          batch_size=3,
3986          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
3987              row_partitions=[
3988                  RowPartitionSpec(
3989                      nrows=3,
3990                      nvals=None,
3991                      uniform_row_length=None,
3992                      dtype=dtypes.int64),
3993                  RowPartitionSpec(
3994                      nrows=None,
3995                      nvals=9,
3996                      uniform_row_length=None,
3997                      dtype=dtypes.int64)
3998              ],
3999              static_inner_shape=tensor_shape.TensorShape([9]),
4000              dtype=dtypes.int64)),
4001      dict(
4002          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4003              row_partitions=[
4004                  RowPartitionSpec(
4005                      nrows=None,
4006                      nvals=None,
4007                      uniform_row_length=None,
4008                      dtype=dtypes.int64)
4009              ],
4010              static_inner_shape=tensor_shape.TensorShape([None]),
4011              dtype=dtypes.int64),
4012          batch_size=3,
4013          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4014              row_partitions=[
4015                  RowPartitionSpec(
4016                      nrows=3,
4017                      nvals=None,
4018                      uniform_row_length=None,
4019                      dtype=dtypes.int64),
4020                  RowPartitionSpec(
4021                      nrows=None,
4022                      nvals=None,
4023                      uniform_row_length=None,
4024                      dtype=dtypes.int64)
4025              ],
4026              static_inner_shape=tensor_shape.TensorShape([None]),
4027              dtype=dtypes.int64)),
4028      dict(
4029          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4030              row_partitions=[
4031                  RowPartitionSpec(
4032                      nrows=None,
4033                      nvals=None,
4034                      uniform_row_length=None,
4035                      dtype=dtypes.int64)
4036              ],
4037              static_inner_shape=tensor_shape.TensorShape(None),
4038              dtype=dtypes.int64),
4039          batch_size=3,
4040          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4041              row_partitions=[
4042                  RowPartitionSpec(
4043                      nrows=3,
4044                      nvals=None,
4045                      uniform_row_length=None,
4046                      dtype=dtypes.int64),
4047                  RowPartitionSpec(
4048                      nrows=None,
4049                      nvals=None,
4050                      uniform_row_length=None,
4051                      dtype=dtypes.int64)
4052              ],
4053              static_inner_shape=tensor_shape.TensorShape(None),
4054              dtype=dtypes.int64)),
4055      dict(
4056          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4057              row_partitions=[
4058                  RowPartitionSpec(
4059                      nrows=None,
4060                      nvals=6,
4061                      uniform_row_length=None,
4062                      dtype=dtypes.int64)
4063              ],
4064              static_inner_shape=tensor_shape.TensorShape([6, 4]),
4065              dtype=dtypes.int64),
4066          batch_size=3,
4067          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4068              row_partitions=[
4069                  RowPartitionSpec(
4070                      nrows=3,
4071                      nvals=None,
4072                      uniform_row_length=None,
4073                      dtype=dtypes.int64),
4074                  RowPartitionSpec(
4075                      nrows=None,
4076                      nvals=18,
4077                      uniform_row_length=None,
4078                      dtype=dtypes.int64)
4079              ],
4080              static_inner_shape=tensor_shape.TensorShape([18, 4]),
4081              dtype=dtypes.int64)),
4082      dict(
4083          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4084              row_partitions=[],
4085              static_inner_shape=tensor_shape.TensorShape(None),
4086              dtype=dtypes.int32),
4087          batch_size=3,
4088          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4089              row_partitions=[],
4090              static_inner_shape=tensor_shape.TensorShape(None),
4091              dtype=dtypes.int32)),
4092      dict(
4093          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4094              row_partitions=[],
4095              static_inner_shape=tensor_shape.TensorShape([8, 9]),
4096              dtype=dtypes.int32),
4097          batch_size=7,
4098          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4099              row_partitions=[],
4100              static_inner_shape=tensor_shape.TensorShape([7, 8, 9]),
4101              dtype=dtypes.int32)),
4102  ])
4103  def test_batch(self,
4104                 spec: DynamicRaggedShape.Spec,
4105                 batch_size: int,
4106                 expected: DynamicRaggedShape.Spec):
4107    encoder = dynamic_ragged_shape._DynamicRaggedShapeBatchEncoder()
4108    actual = encoder.batch(spec, batch_size)
4109    self.assertDynamicRaggedShapeSpecEqual(actual, expected)
4110
4111  @parameterized.parameters([
4112      dict(
4113          spec=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4114              row_partitions=[
4115                  RowPartitionSpec(
4116                      nrows=6,
4117                      nvals=3,
4118                      uniform_row_length=None,
4119                      dtype=dtypes.int32)],
4120              static_inner_shape=tensor_shape.TensorShape([3]),
4121              dtype=dtypes.int32),
4122          expected=dynamic_ragged_shape.DynamicRaggedShape.Spec(
4123              row_partitions=[],
4124              static_inner_shape=tensor_shape.TensorShape([None]),
4125              dtype=dtypes.int32))
4126  ])
4127  def test_unbatch(self, spec: DynamicRaggedShape.Spec,
4128                   expected: DynamicRaggedShape.Spec):
4129    encoder = dynamic_ragged_shape._DynamicRaggedShapeBatchEncoder()
4130    actual = encoder.unbatch(spec)
4131    self.assertDynamicRaggedShapeSpecEqual(actual, expected)
4132
4133  def test_repr(self):
4134    original = dynamic_ragged_shape.DynamicRaggedShape.Spec(
4135        row_partitions=[
4136            RowPartitionSpec(
4137                nrows=6,
4138                nvals=None,
4139                uniform_row_length=None,
4140                dtype=dtypes.int64)
4141        ],
4142        static_inner_shape=tensor_shape.TensorShape([None]),
4143        dtype=dtypes.int64)
4144    representation = repr(original)
4145    static_inner_shape = tensor_shape.TensorShape([None])
4146    expected = ('DynamicRaggedShape.Spec(' +
4147                'row_partitions=(RowPartitionSpec(' +
4148                'nrows=6, nvals=None, uniform_row_length=None, ' +
4149                'dtype=tf.int64),), ' +
4150                f'static_inner_shape={static_inner_shape!r}, ' +
4151                'dtype=tf.int64)')
4152    self.assertEqual(representation, expected)
4153
4154  @parameterized.parameters([
4155      dict(
4156          lengths=[3, 4, 5],
4157          expected=DynamicRaggedShape.Spec(
4158              row_partitions=[],
4159              static_inner_shape=tensor_shape.TensorShape([3, 4, 5]),
4160              dtype=dtypes.int64)),
4161      dict(
4162          lengths=[2, (4, 1), 5],
4163          expected=DynamicRaggedShape.Spec(
4164              row_partitions=[RowPartitionSpec(nrows=2, nvals=5)],
4165              static_inner_shape=tensor_shape.TensorShape([5, 5]),
4166              dtype=dtypes.int64)),
4167      dict(
4168          lengths=[2, (4, 1), 5],
4169          dtype=dtypes.int32,
4170          expected=DynamicRaggedShape.Spec(
4171              row_partitions=[
4172                  RowPartitionSpec(nrows=2, nvals=5, dtype=dtypes.int32)],
4173              static_inner_shape=tensor_shape.TensorShape([5, 5]),
4174              dtype=dtypes.int32)),
4175  ])
4176  def test_from_value(self, lengths, expected, dtype=None):
4177    original = DynamicRaggedShape.from_lengths(lengths)
4178    if dtype is not None:
4179      original = original.with_dtype(dtype)
4180    actual = dynamic_ragged_shape.DynamicRaggedShape.Spec.from_value(original)
4181    self.assertTensorShapeEqual(actual, expected)
4182
# Entry point when the file is executed directly: delegates to
# googletest.main().
if __name__ == '__main__':
  googletest.main()
4185