1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tf.ragged.dynamic_ragged_shape.""" 16 17from typing import Sequence, Union 18 19from absl.testing import parameterized 20import numpy as np 21 22from tensorflow.python.client import session 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.eager import context 25from tensorflow.python.eager import def_function 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors_impl 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.framework import tensor_spec 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.framework import test_util 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import gen_math_ops 36from tensorflow.python.ops import gradients_impl 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import string_ops 39from tensorflow.python.ops.ragged import dynamic_ragged_shape 40from tensorflow.python.ops.ragged import ragged_array_ops 41from tensorflow.python.ops.ragged import ragged_factory_ops 42from tensorflow.python.ops.ragged import ragged_tensor 43from tensorflow.python.ops.ragged.dynamic_ragged_shape import _LayerBroadcaster 44from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape 45from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor 46from tensorflow.python.ops.ragged.row_partition import RowPartition 47from tensorflow.python.ops.ragged.row_partition import RowPartitionSpec 48from tensorflow.python.platform import googletest 49 50 51def _to_row_partitions_from_lengths( 52 lengths: Sequence[Union[int, Sequence[int]]]) -> Sequence[RowPartition]: 53 """Allow ragged and uniform shapes to be specified. 54 55 For example, [2, [2,1], 2] represents a shape like: 56 [[[0, 0], [0, 0]], [[0, 0]]] 57 58 Args: 59 lengths: a list of integers and lists of integers. 60 61 Returns: 62 a sequence of RowPartitions. 63 """ 64 (result, 65 _) = dynamic_ragged_shape._to_row_partitions_and_nvals_from_lengths(lengths) 66 return result 67 68 69def _to_ragged_tensor_from_lengths( 70 values, lengths: Sequence[Union[int, Sequence[int]]]) -> RaggedTensor: 71 """Specify a ragged tensor (or tensor) from lengths and values.""" 72 row_partitions = _to_row_partitions_from_lengths(lengths) 73 values = constant_op.constant(values) 74 if not row_partitions: 75 return values 76 return RaggedTensor._from_nested_row_partitions(values, row_partitions) 77 78 79def _divides(a, b): 80 return b % a == 0 81 82 83def _next_prime(primes_so_far): 84 first_candidate = 2 85 if primes_so_far: 86 first_candidate = primes_so_far[-1] + 1 87 while True: 88 if not any([_divides(x, first_candidate) for x in primes_so_far]): 89 return first_candidate 90 first_candidate = first_candidate + 1 91 92 93def _lowest_primes(n): 94 """Give the lowest n primes.""" 95 result = [] 96 for _ in range(n): 97 result.append(_next_prime(result)) 98 return result 99 100 101def _num_elements_of_lengths_with_rows(rows, 102 lengths: Sequence[Union[int, 103 Sequence[int]]]): 104 """Helper function for _num_elements_of_lengths.""" 105 if not lengths: 106 return rows 107 next_length = lengths[0] 108 if isinstance(next_length, int): 109 return _num_elements_of_lengths_with_rows(next_length * rows, lengths[1:]) 110 else: 111 return _num_elements_of_lengths_with_rows(sum(next_length), lengths[1:]) 112 113 114def _num_elements_of_lengths(lengths: Sequence[Union[int, Sequence[int]]]): 115 """Static version of DynamicRaggedShape.from_lengths(lengths)._num_elements().""" 116 return _num_elements_of_lengths_with_rows(1, lengths) 117 118 119def _to_prime_tensor_from_lengths( 120 lengths: Sequence[Union[int, Sequence[int]]]) -> RaggedTensor: 121 """Create a tensor of primes with the shape specified.""" 122 shape = DynamicRaggedShape.from_lengths(lengths) 123 num_elements = _num_elements_of_lengths(lengths) 124 return ragged_array_ops.ragged_reshape(_lowest_primes(num_elements), shape) 125 126 127@test_util.run_all_in_graph_and_eager_modes 128class DynamicRaggedShapeTest(test_util.TensorFlowTestCase, 129 parameterized.TestCase): 130 131 def assertRowPartitionEq(self, 132 x: RowPartition, 133 y: RowPartition, 134 msg=None) -> None: 135 self.assertAllEqual(x.row_splits(), y.row_splits(), msg=msg) 136 137 def assertShapeEq(self, 138 x: DynamicRaggedShape, 139 y: DynamicRaggedShape, 140 msg=None) -> None: 141 assert isinstance(x, DynamicRaggedShape) 142 assert isinstance(y, DynamicRaggedShape) 143 if msg is None: 144 msg = '' 145 self.assertLen( 146 x.row_partitions, len(y.row_partitions), msg=msg + ': length unequal') 147 for i in range(len(x.row_partitions)): 148 x_dims = x.row_partitions[i] 149 y_dims = y.row_partitions[i] 150 self.assertRowPartitionEq( 151 x_dims, y_dims, msg=msg + ': row_partition ' + str(i)) 152 self.assertAllEqual( 153 x.inner_shape, y.inner_shape, msg=msg + ': shapes unequal') 154 155 def assertLayerBroadcasterEq(self, x: _LayerBroadcaster, 156 y: _LayerBroadcaster) -> None: 157 assert isinstance(x, _LayerBroadcaster) 158 assert isinstance(y, _LayerBroadcaster) 159 self.assertAllEqual(x.gather_index, y.gather_index) 160 161 def assertBroadcasterEq(self, x: dynamic_ragged_shape._Broadcaster, 162 y: dynamic_ragged_shape._Broadcaster) -> None: 163 assert isinstance(x, dynamic_ragged_shape._Broadcaster) 164 assert isinstance(y, dynamic_ragged_shape._Broadcaster) 165 self.assertShapeEq(x.source_shape, y.source_shape) 166 self.assertShapeEq(x.target_shape, y.target_shape) 167 self.assertLen(x._layer_broadcasters, len(y._layer_broadcasters)) 168 for x_layer, y_layer in zip(x._layer_broadcasters, y._layer_broadcasters): 169 self.assertLayerBroadcasterEq(x_layer, y_layer) 170 171 @parameterized.parameters([ 172 dict(value='x', row_partitions=[], inner_shape=()), 173 dict(value=['a', 'b', 'c'], row_partitions=[], inner_shape=[3]), 174 dict( 175 value=[['a', 'b', 'c'], ['d', 'e', 'f']], 176 row_partitions=(), 177 inner_shape=[2, 3]), 178 dict( 179 value=[[['a', 'b', 'c'], ['d', 'e', 'f']]], 180 row_partitions=(), 181 inner_shape=[1, 2, 3]), 182 dict( 183 value=ragged_factory_ops.constant_value([['a', 'b', 'c'], ['d', 'e']], 184 ragged_rank=1), 185 row_partitions=[[0, 3, 5]], 186 inner_shape=[5]), 187 dict( 188 value=ragged_factory_ops.constant_value( 189 [[['a', 'b', 'c'], ['d', 'e', 'f']]], ragged_rank=1), 190 row_partitions=[[0, 2]], 191 inner_shape=[2, 3]), 192 dict( 193 value=ragged_factory_ops.constant_value( 194 [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1), 195 row_partitions=[[0, 2, 3]], 196 inner_shape=[3, 2, 1]), 197 dict( 198 value=ragged_factory_ops.constant_value([[10, 20], [30]]), 199 row_partitions=[[0, 2, 3]], 200 inner_shape=[3]), 201 # Docstring examples: 202 dict(value=[[1, 2, 3], [4, 5, 6]], row_partitions=[], inner_shape=[2, 3]), 203 dict( 204 value=ragged_factory_ops.constant_value([[1, 2], [], [3, 4, 5]]), 205 row_partitions=[[0, 2, 2, 5]], 206 inner_shape=[5]), 207 dict( 208 value=ragged_factory_ops.constant_value([[[1, 2], [3, 4]], [[5, 6]]], 209 ragged_rank=1), 210 row_partitions=[[0, 2, 3]], 211 inner_shape=[3, 2]), 212 dict( 213 value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]), 214 row_partitions=[[0, 2, 3], [0, 2, 3, 5]], 215 inner_shape=[5]), 216 ]) 217 def testFromTensor(self, value, row_partitions, inner_shape): 218 shape = DynamicRaggedShape.from_tensor(value) 219 row_partitions = [RowPartition.from_row_splits(x) for x in row_partitions] 220 expected = DynamicRaggedShape(row_partitions, inner_shape) 221 self.assertShapeEq(shape, expected) 222 223 # pylint:disable=g-long-lambda 224 @parameterized.parameters([ 225 # from_lengths | row_partitions | inner_shape 226 # ---------------------- | --------------------------| ------------- 227 # [] | [] | [] 228 # [2, (3, 2)] | [RP([3, 2])] | [5] 229 # [2, 2] | [] | [2, 2] 230 # [2, (3, 2), 7] | [RP([3, 2])] | [5, 7] 231 # [2, (2, 2), 3] | [RP([2, 2])] | [4, 3] 232 # [2, 2, 3] | [] | [2, 2, 3] 233 # [2, (2, 1), (2, 0, 3)] | [RP(2, 1), RP([2, 0, 3])] | [5] 234 235 dict(lengths=[], row_partitions=[], inner_shape=[]), 236 dict( 237 lengths=[2, (3, 2)], 238 row_partitions=lambda: [RowPartition.from_row_lengths([3, 2])], 239 inner_shape=[5]), 240 dict(lengths=[2, 2], row_partitions=[], inner_shape=[2, 2]), 241 dict( 242 lengths=[2, (3, 2), 7], 243 row_partitions=lambda: [RowPartition.from_row_lengths([3, 2])], 244 inner_shape=[5, 7]), 245 dict( 246 lengths=[2, (2, 2), 3], 247 row_partitions=lambda: [RowPartition.from_row_lengths([2, 2])], 248 inner_shape=[4, 3]), 249 dict(lengths=[2, 2, 3], row_partitions=[], inner_shape=[2, 2, 3]), 250 dict( 251 lengths=[2, (2, 1), (2, 0, 3)], 252 row_partitions=lambda: [ 253 RowPartition.from_row_lengths([2, 1]), 254 RowPartition.from_row_lengths([2, 0, 3]) 255 ], 256 inner_shape=[5]), 257 # from_lengths | num_row | row_partitions | inner_shape 258 # : partitions : : 259 # ---------------| -----------|--------------------------|------------ 260 # [2, (3, 2), 2] | 2 | [RP([3, 2]), URP(2, 10)] | [10] 261 # [2, 2] | 1 | [URP(2, 4)] | [4] 262 # [2, 2, 3] | 0 | [] | [2, 2, 3] 263 # [2, 2, 3] | 1 | [URP(2, 4)] | [4, 3] 264 # [2, 2, 3] | 2 | [URP(2, 4), URP(3, 12)] | [12] 265 dict(lengths=[2, (3, 2), 2], 266 num_row_partitions=2, 267 row_partitions=lambda: [RowPartition.from_row_lengths([3, 2]), 268 RowPartition.from_uniform_row_length(2, 10)], 269 inner_shape=[10]), 270 dict(lengths=[2, 2], 271 num_row_partitions=1, 272 row_partitions=lambda: [RowPartition.from_uniform_row_length(2, 4)], 273 inner_shape=[4]), 274 dict(lengths=[2, 2, 3], 275 num_row_partitions=0, 276 row_partitions=[], 277 inner_shape=[2, 2, 3]), 278 dict(lengths=[2, 2, 3], 279 num_row_partitions=1, 280 row_partitions=lambda: [RowPartition.from_uniform_row_length(2, 4)], 281 inner_shape=[4, 3]), 282 dict(lengths=[2, 2, 3], 283 num_row_partitions=2, 284 row_partitions=lambda: [RowPartition.from_uniform_row_length(2, 4), 285 RowPartition.from_uniform_row_length(3, 12)], 286 inner_shape=[12]) 287 ]) 288 def testFromLengths(self, 289 lengths, 290 row_partitions, 291 inner_shape, 292 num_row_partitions=None): 293 if callable(row_partitions): 294 row_partitions = row_partitions() 295 shape = DynamicRaggedShape.from_lengths( 296 lengths, num_row_partitions=num_row_partitions) 297 expected = DynamicRaggedShape(row_partitions, inner_shape) 298 self.assertShapeEq(shape, expected) 299 300 @parameterized.parameters([ 301 dict( 302 lengths=[2, (2, 1, 3)], 303 num_row_partitions=1, 304 msg='Shape not consistent'), 305 dict( 306 lengths=[2, 3], 307 num_row_partitions=2, 308 msg='num_row_partitions should be less than'), 309 dict( 310 lengths=[], 311 num_row_partitions=3, 312 msg='num_row_partitions==0 for a scalar shape'), 313 dict( 314 lengths=[(5, 3), 3], 315 num_row_partitions='a', 316 msg='num_row_partitions should be an int or None'), 317 dict( 318 lengths=[(5, 'a'), 3], 319 num_row_partitions=0, 320 msg='element of lengths should be int or tuple of ints'), 321 dict( 322 lengths=['a'], 323 num_row_partitions=0, 324 msg='element of lengths should be int or tuple of ints'), 325 dict(lengths=7, num_row_partitions=0, msg='lengths should be a list') 326 ]) 327 def testFromLengthsError(self, lengths, msg, num_row_partitions=None): 328 with self.assertRaisesRegex(ValueError, msg): 329 DynamicRaggedShape.from_lengths( 330 lengths, num_row_partitions=num_row_partitions) 331 332 def testGetItemSliceRankUnknownA(self): 333 if not context.executing_eagerly(): 334 original_t = array_ops.placeholder_with_default(np.array([4, 5, 3]), None) 335 sh = DynamicRaggedShape.from_tensor(original_t) 336 known = sh[:1] 337 self.assertIsNone(known.rank) 338 339 def testGetItemSliceRankUnknownLong(self): 340 if not context.executing_eagerly(): 341 original_t = array_ops.placeholder_with_default(np.array([4, 5, 3]), None) 342 sh = DynamicRaggedShape.from_tensor(original_t) 343 unknown = sh[:20] 344 self.assertIsNone(unknown.rank) 345 346 def testGetItemSliceRankKnownLong(self): 347 if not context.executing_eagerly(): 348 original_t = constant_op.constant([4, 5, 3], dtypes.float32) 349 sh = DynamicRaggedShape.from_tensor(original_t) 350 unknown = sh[:20] 351 self.assertEqual(unknown.rank, 1) 352 353 def testGetBroadcaster(self): 354 origin_shape = DynamicRaggedShape( 355 [RowPartition.from_uniform_row_length(1, 3)], inner_shape=[3]) 356 dest_shape = DynamicRaggedShape( 357 [RowPartition.from_uniform_row_length(2, 6)], inner_shape=[6]) 358 actual = dynamic_ragged_shape._get_broadcaster(origin_shape, dest_shape) 359 expected = dynamic_ragged_shape._Broadcaster(origin_shape, dest_shape, [ 360 _LayerBroadcaster.from_gather_index([0, 1, 2]), 361 _LayerBroadcaster.from_gather_index([0, 0, 1, 1, 2, 2]) 362 ]) 363 self.assertBroadcasterEq(actual, expected) 364 365 def testGetBroadcaster2(self): 366 origin_shape = DynamicRaggedShape([], inner_shape=[]) 367 dest_shape = DynamicRaggedShape([RowPartition.from_row_splits([0, 2, 3])], 368 inner_shape=[3]) 369 actual = dynamic_ragged_shape._get_broadcaster(origin_shape, dest_shape) 370 expected = dynamic_ragged_shape._Broadcaster(origin_shape, dest_shape, []) 371 self.assertBroadcasterEq(actual, expected) 372 373 @parameterized.parameters([ 374 dict(lengths=[2, 3], axis=0, expected=2), 375 dict(lengths=[2, 3], axis=1, expected=6), 376 dict(lengths=[2, 3], axis=-1, expected=6), 377 dict(lengths=[2, 3], axis=-2, expected=2), 378 dict(lengths=[2, 3, 4], axis=0, expected=2), 379 dict(lengths=[2, 3, 4], axis=1, expected=6), 380 dict(lengths=[2, 3, 4], axis=2, expected=24), 381 dict(lengths=[2, 3, 4], axis=-1, expected=24), 382 dict(lengths=[2, 3, 4], axis=-2, expected=6), 383 dict(lengths=[2, 3, 4], axis=-3, expected=2), 384 dict(lengths=[2, (2, 3), 7], axis=0, expected=2), 385 dict(lengths=[2, (2, 3), 7], axis=1, expected=5), 386 dict(lengths=[2, (2, 3), 7], axis=2, expected=35), 387 dict(lengths=[2, (2, 3), 7], axis=-1, expected=35), 388 dict(lengths=[2, (2, 3), 7], axis=-2, expected=5), 389 dict(lengths=[2, (2, 3), 7], axis=-3, expected=2), 390 ]) 391 def testNumSlicesInDimension(self, lengths, axis, expected): 392 original = DynamicRaggedShape.from_lengths(lengths) 393 actual = original._num_slices_in_dimension(axis) 394 self.assertAllEqual(expected, actual) 395 396 @parameterized.parameters([ 397 dict( 398 lengths=[2, 3], 399 axis=0.5, 400 error_type=TypeError, 401 error_regex='axis must be an integer'), 402 ]) 403 def testNumSlicesInDimensionRaises(self, lengths, axis, error_type, 404 error_regex): 405 original = DynamicRaggedShape.from_lengths(lengths) 406 with self.assertRaisesRegex(error_type, error_regex): 407 original._num_slices_in_dimension(axis) 408 409 @parameterized.parameters([ 410 dict( 411 lengths=[2, (1, 2), 4], 412 new_dense_rank=3, 413 error_type=ValueError, 414 error_regex='Cannot get an inner shape'), 415 dict( 416 lengths=[], 417 new_dense_rank=3, 418 error_type=ValueError, 419 error_regex='old inner_rank cannot be zero'), 420 dict( 421 lengths=[2, 3], 422 new_dense_rank=0, 423 error_type=ValueError, 424 error_regex='new_inner_rank cannot be zero'), 425 ]) 426 def testAltInnerShapeRaises(self, lengths, new_dense_rank, error_type, 427 error_regex): 428 original = DynamicRaggedShape.from_lengths(lengths) 429 with self.assertRaisesRegex(error_type, error_regex): 430 original._alt_inner_shape(new_dense_rank) 431 432 @parameterized.parameters([ 433 dict( 434 lengths=[2, (1, 2), 4], new_dense_rank=2, expected_inner_shape=[3, 435 4]), 436 ]) 437 def testAltInnerShape(self, lengths, new_dense_rank, expected_inner_shape): 438 original = DynamicRaggedShape.from_lengths(lengths) 439 actual = original._alt_inner_shape(new_dense_rank) 440 self.assertAllEqual(actual, expected_inner_shape) 441 442 def testWithNumRowPartitionsDynamic(self): 443 @def_function.function( 444 input_signature=[tensor_spec.TensorSpec([3], dtypes.int64)]) 445 def fun(x): 446 shape = DynamicRaggedShape([ 447 RowPartition.from_row_lengths([1, 3], dtype=dtypes.int64), 448 RowPartition.from_row_lengths([2, 3, 4, 5], dtype=dtypes.int64) 449 ], x) 450 result = shape._with_num_row_partitions(3) 451 expected = DynamicRaggedShape([ 452 RowPartition.from_row_lengths([1, 3], dtype=dtypes.int64), 453 RowPartition.from_row_lengths([2, 3, 4, 5], dtype=dtypes.int64), 454 RowPartition.from_uniform_row_length( 455 2, nrows=14, nvals=28, dtype=dtypes.int64) 456 ], [14 * 2, 3]) 457 self.assertShapeEq(expected, result) 458 fun(constant_op.constant([14, 2, 3], dtype=dtypes.int64)) 459 460 @parameterized.parameters([ 461 dict( 462 lengths=[2], 463 new_dense_rank=2, 464 error_type=ValueError, 465 error_regex='Cannot change inner_rank if'), 466 ]) 467 def testWithDenseRankRaises(self, lengths, new_dense_rank, error_type, 468 error_regex): 469 original = DynamicRaggedShape.from_lengths(lengths) 470 with self.assertRaisesRegex(error_type, error_regex): 471 original._with_inner_rank(new_dense_rank) 472 473 @parameterized.parameters([ 474 dict( 475 lengths=[2, (1, 2)], 476 num_row_partitions=2, 477 error_type=ValueError, 478 error_regex='num_row_partitions must be less than rank'), 479 dict( 480 lengths=[2], 481 num_row_partitions=-1, 482 error_type=ValueError, 483 error_regex='num_row_partitions must be nonnegative'), 484 dict( 485 lengths=[2], 486 num_row_partitions=0.5, 487 error_type=ValueError, 488 error_regex='num_row_partitions must be an int'), 489 ]) 490 def testWithNumRowPartitionsRaises(self, lengths, num_row_partitions, 491 error_type, error_regex): 492 original = DynamicRaggedShape.from_lengths(lengths) 493 with self.assertRaisesRegex(error_type, error_regex): 494 original._with_num_row_partitions(num_row_partitions) 495 496 def testDimensionRaises(self): 497 original = DynamicRaggedShape.from_lengths([2, (1, 2)]) 498 with self.assertRaisesRegex(TypeError, 'index should be an int'): 499 # This error is not exposed directly to the end user. 500 original._dimension(0.5) 501 502 @parameterized.parameters([ 503 # The whole shape (num_row_partitions=0, start=negative, stop=really big) 504 dict(lengths=[2, 3], s=slice(-1000, 100), expected_lengths=[2, 3]), 505 # The whole shape (num_row_partitions=0, stop=really big) 506 dict(lengths=[2, 3], s=slice(0, 100), expected_lengths=[2, 3]), 507 # The whole shape (num_row_partitions=0, stop=None) 508 dict(lengths=[2, 3], s=slice(0, None), expected_lengths=[2, 3]), 509 # start = None, num_row_partitions=1, stop = 3 < rank = 4 510 dict( 511 lengths=[2, (1, 2), 3, 4], 512 s=slice(None, 3), 513 expected_lengths=[2, (1, 2), 3]), 514 # start = 1, num_row_partitions=1, stop = 4, rank = 4 515 dict( 516 lengths=[2, 3, 3, 4], 517 num_row_partitions=1, 518 s=slice(1, 4), 519 expected_lengths=[3, 3, 4]), 520 # start = 1, num_row_partitions=1, stop = 3 < rank = 4 521 dict( 522 lengths=[2, 3, 3, 4], 523 num_row_partitions=1, 524 s=slice(1, 3), 525 expected_lengths=[3, 3]), 526 # start = 1, num_row_partitions=2, stop = 3 < rank = 4 527 dict( 528 lengths=[2, 3, 4, 3, 4], 529 num_row_partitions=2, 530 s=slice(1, 3), 531 expected_lengths=[3, 4]), 532 # start = 0, num_row_partitions=1, stop = 3 < rank = 4 533 dict( 534 lengths=[2, (1, 2), 3, 4], 535 s=slice(0, 3), 536 expected_lengths=[2, (1, 2), 3]), 537 # start = 0, num_row_partitions=0, stop < rank 538 dict(lengths=[2, 3, 4], s=slice(0, 2), expected_lengths=[2, 3]), 539 # start=0 < stop=2 <= num_row_partitions 540 dict( 541 lengths=[2, (1, 2), (3, 4, 5)], 542 s=slice(0, 2), 543 expected_lengths=[2, (1, 2)]), 544 # start=0 < stop=1 <= num_row_partitions 545 dict(lengths=[2, (1, 2), (3, 4, 5)], s=slice(0, 1), expected_lengths=[2]), 546 # Reversed indices, gives scalar shape. 547 dict(lengths=[2, 3], s=slice(2, 0), expected_lengths=[]), 548 # The whole shape (num_row_partitions=0) 549 dict(lengths=[2, 3], s=slice(0, 2), expected_lengths=[2, 3]), 550 ]) 551 def testGetItemSlice(self, 552 lengths, 553 s, 554 expected_lengths, 555 num_row_partitions=None): 556 original = DynamicRaggedShape.from_lengths(lengths) 557 if num_row_partitions is not None: 558 original = original._with_num_row_partitions(num_row_partitions) 559 expected = DynamicRaggedShape.from_lengths(expected_lengths) 560 actual = original[s] 561 self.assertShapeEq(expected, actual) 562 563 @parameterized.parameters([ 564 dict( 565 lengths=[2, (1, 2), 3, 4], 566 index=0.5, 567 error_type=TypeError, 568 error_regex='Argument is not an int or a slice'), 569 dict( 570 lengths=[2, (1, 2), 3, 4], 571 index=slice(0, 1, 2), 572 error_type=IndexError, 573 error_regex='Cannot stride through a shape'), 574 dict( 575 lengths=[2, (1, 2), 3, 4], 576 index=1, 577 error_type=ValueError, 578 error_regex='Index 1 is not uniform'), 579 dict( 580 lengths=[2, 3, 3, 4], 581 num_row_partitions=1, 582 index=-20, 583 error_type=IndexError, 584 error_regex='Index must be non-negative'), 585 dict( 586 lengths=[2, 3, 3, 4], 587 num_row_partitions=1, 588 index=9, 589 error_type=IndexError, 590 error_regex='Index is too big'), 591 ]) 592 def testGetItemRaisesStatic(self, 593 lengths, 594 index, 595 error_type, 596 error_regex, 597 num_row_partitions=None): 598 original = DynamicRaggedShape.from_lengths(lengths) 599 if num_row_partitions is not None: 600 original = original._with_num_row_partitions(num_row_partitions) 601 with self.assertRaisesRegex(error_type, error_regex): 602 original[index] # pylint: disable=pointless-statement 603 604 def testBroadcastToAlt(self): 605 origin = RaggedTensor.from_uniform_row_length([3, 4, 5], 606 uniform_row_length=1) 607 expected = RaggedTensor.from_uniform_row_length([3, 3, 4, 4, 5, 5], 608 uniform_row_length=2) 609 expected_shape = DynamicRaggedShape.from_tensor(expected) 610 actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape) 611 self.assertAllEqual(actual, expected) 612 613 @parameterized.parameters([ 614 dict( 615 source_lengths=[3], 616 target_lengths=[1, 3], 617 target_num_row_partitions=1, 618 expected_gather_indices=[[0, 1, 2]]), 619 dict( # BroadcastTensorTo4 broadcaster. 620 source_lengths=[2, 3], 621 target_lengths=[1, 2, 3], 622 target_num_row_partitions=2, 623 expected_gather_indices=[[0, 1], [0, 1, 2, 3, 4, 5]]), 624 dict( # raggedTensor1. 625 source_lengths=[3, (1, 2, 1), 2, 2], 626 source_num_row_partitions=3, 627 target_lengths=[1, 1, 3, (1, 2, 1), 2, 2], 628 target_num_row_partitions=5, 629 expected_gather_indices=[[0, 1, 2], [0, 1, 2, 3], 630 [0, 1, 2, 3, 4, 5, 6, 7], 631 [ 632 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 633 13, 14, 15 634 ]]), 635 ]) 636 def testBroadcaster(self, 637 source_lengths, 638 target_lengths, 639 expected_gather_indices, 640 source_num_row_partitions=None, 641 target_num_row_partitions=None): 642 source = DynamicRaggedShape.from_lengths(source_lengths) 643 if source_num_row_partitions is not None: 644 source = source._with_num_row_partitions(source_num_row_partitions) 645 target = DynamicRaggedShape.from_lengths(target_lengths) 646 if target_num_row_partitions is not None: 647 target = target._with_num_row_partitions(target_num_row_partitions) 648 649 expected_gather_indices = [ 650 _LayerBroadcaster.from_gather_index(x) for x in expected_gather_indices 651 ] 652 actual = dynamic_ragged_shape._get_broadcaster(source, target) 653 expected = dynamic_ragged_shape._Broadcaster(source, target, 654 expected_gather_indices) 655 self.assertBroadcasterEq(actual, expected) 656 657 def testRaggedGradientSimple1(self): 658 if context.executing_eagerly(): 659 return 660 def func(x): 661 rt1 = RaggedTensor.from_row_splits( 662 values=x, row_splits=[0, 4, 7, 8], validate=False) 663 rt2 = rt1 * [[10], [100], [1000]] 664 return rt2.flat_values 665 666 x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0]) 667 y = func(x) 668 g = gradients_impl.gradients(ys=y, xs=x)[0] 669 670 self.assertAllClose(ops.convert_to_tensor(g), 671 [10., 10., 10., 10., 100., 100., 100, 1000.]) 672 673 def testRaggedGradientSimple2(self): 674 if context.executing_eagerly(): 675 return 676 def func(x): 677 rt1 = RaggedTensor._from_row_partition( 678 x, 679 RowPartition.from_row_splits(row_splits=[0, 4, 7, 8], validate=False)) 680 rt2 = rt1 * [[10], [100], [1000]] 681 return rt2.flat_values 682 683 x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0]) 684 y = func(x) 685 g = gradients_impl.gradients(ys=y, xs=x)[0] 686 687 self.assertAllClose(ops.convert_to_tensor(g), 688 [10., 10., 10., 10., 100., 100., 100, 1000.]) 689 690 def testRaggedGradientSimple3(self): 691 if context.executing_eagerly(): 692 return 693 def func(x): 694 rt1 = RaggedTensor._from_row_partition( 695 x, 696 RowPartition.from_row_splits(row_splits=[0, 4, 7, 8], 697 dtype=dtypes.int32, validate=False)) 698 rt2 = rt1 * [[10], [100], [1000]] 699 return rt2.flat_values 700 701 x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0]) 702 y = func(x) 703 g = gradients_impl.gradients(ys=y, xs=x)[0] 704 705 self.assertAllClose(ops.convert_to_tensor(g), 706 [10., 10., 10., 10., 100., 100., 100, 1000.]) 707 708 def testRaggedMul(self): 709 if context.executing_eagerly(): 710 return 711 x = constant_op.constant([3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0]) 712 rt1 = RaggedTensor._from_row_partition( 713 x, 714 RowPartition.from_row_splits(row_splits=[0, 4, 7, 8], 715 dtype=dtypes.int64, validate=False)) 716 rt2 = rt1 * [[10], [100], [1000]] 717 self.assertAllClose(rt2.flat_values, 718 [30.0, 10.0, 40.0, 10.0, 100.0, 0.0, 200.0, 1000.0]) 719 720 def testBroadcastToGradient(self): 721 if context.executing_eagerly(): 722 return 723 def func(x): 724 target_shape = DynamicRaggedShape.from_row_partitions( 725 [RowPartition.from_row_splits(row_splits=[0, 4, 7, 8])]) 726 727 rt = dynamic_ragged_shape.broadcast_to(x, target_shape) 728 return rt.flat_values 729 730 x = constant_op.constant([[3.0], [1.0], [4.0]]) 731 y = func(x) 732 g = gradients_impl.gradients(ys=y, xs=x)[0] 733 734 self.assertAllClose(g, [[4.], [3.], [1.]]) 735 736 def testBroadcastScalarToScalar(self): 737 origin = constant_op.constant(b'x') 738 expected = origin 739 expected_shape = DynamicRaggedShape.from_tensor(expected) 740 actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape) 741 self.assertAllEqual(actual, expected) 742 743 @parameterized.parameters([ 744 dict(lengths=[2, 3], axis=0), 745 dict(lengths=[2, 3], axis=1), 746 dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=0), 747 dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=2), 748 dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=3), 749 ]) 750 def testIsUniformTrue(self, lengths, axis, num_row_partitions=None): 751 shape = DynamicRaggedShape.from_lengths(lengths) 752 if num_row_partitions is not None: 753 shape = shape._with_num_row_partitions(num_row_partitions) 754 actual = shape.is_uniform(axis) 755 self.assertTrue(actual) 756 757 @parameterized.parameters([ 758 dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=1), 759 dict( 760 lengths=[2, (2, 3), 2, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 4], 761 num_row_partitions=3, 762 axis=3), 763 ]) 764 def testIsUniformFalse(self, lengths, num_row_partitions, axis): 765 shape = DynamicRaggedShape.from_lengths(lengths)._with_num_row_partitions( 766 num_row_partitions) 767 actual = shape.is_uniform(axis) 768 self.assertFalse(actual) 769 770 @parameterized.parameters([ 771 dict( 772 lengths=[2, (2, 3), 7, 4], 773 num_row_partitions=2, 774 axis=10, 775 error_type=IndexError, 776 error_regex='Expected axis=10 < rank=4'), 777 dict( 778 lengths=[2, (2, 3), 7, 4], 779 num_row_partitions=2, 780 axis=-1, 781 error_type=IndexError, 782 error_regex='Negative axis values are not supported'), 783 dict( 784 lengths=[2, (2, 3), 7, 4], 785 num_row_partitions=2, 786 axis=0.5, 787 error_type=TypeError, 788 error_regex='axis must be an integer'), 789 ]) 790 def testIsUniformRaises(self, lengths, num_row_partitions, axis, error_type, 791 error_regex): 792 shape = DynamicRaggedShape.from_lengths(lengths)._with_num_row_partitions( 793 num_row_partitions) 794 with self.assertRaisesRegex(error_type, error_regex): 795 shape.is_uniform(axis) 796 797 @parameterized.parameters([ 798 dict(lengths=[2, 3], num_row_partitions_a=0, num_row_partitions_b=1), 799 dict( 800 lengths=[2, (2, 3), 7, 4], 801 num_row_partitions_a=2, 802 num_row_partitions_b=1), 803 dict( 804 lengths=[3, (2, 0, 1), 5], 805 num_row_partitions_a=1, 806 num_row_partitions_b=2) 807 ]) 808 def testWithNumRowPartitions(self, lengths, num_row_partitions_a, 809 num_row_partitions_b): 810 shape = DynamicRaggedShape.from_lengths(lengths) 811 original_row_partitions = shape.num_row_partitions 812 shape_a = shape._with_num_row_partitions(num_row_partitions_a) 813 self.assertEqual(shape_a.num_row_partitions, num_row_partitions_a) 814 shape_b = shape_a._with_num_row_partitions(num_row_partitions_b) 815 self.assertEqual(shape_b.num_row_partitions, num_row_partitions_b) 816 actual = shape_b._with_num_row_partitions(original_row_partitions) 817 self.assertShapeEq(actual, shape) 818 819 @parameterized.parameters([ 820 dict( 821 lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=-2, expected=7), 822 dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=0, expected=2), 823 dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=2, expected=7), 824 dict(lengths=[2, (2, 3), 7, 4], num_row_partitions=2, axis=3, expected=4), 825 dict( 826 lengths=[2, (2, 3), 7, 4, 3], 827 num_row_partitions=2, 828 axis=4, 829 expected=3), 830 dict(lengths=[3], axis=0, expected=3), 831 dict(lengths=[3, 4, 5], axis=0, expected=3), 832 dict(lengths=[3, 4, 5], axis=1, expected=4), 833 dict(lengths=[3, 4, 5], axis=2, expected=5), 834 ]) 835 def testGetItem(self, lengths, axis, expected, num_row_partitions=None): 836 shape = DynamicRaggedShape.from_lengths(lengths) 837 if num_row_partitions is not None: 838 shape = shape._with_num_row_partitions(num_row_partitions) 839 actual = shape[axis] 840 self.assertAllEqual(actual, expected) 841 842 def testNumElements(self): 843 shape = DynamicRaggedShape.from_lengths([2, 3, 4, 844 5])._with_num_row_partitions(2) 845 self.assertAllEqual(shape._num_elements(), 120) 846 847 def test_to_row_partitions_from_lengths(self): 848 # Testing the test. 849 actual = _to_row_partitions_from_lengths([1, 2, 3]) 850 expected = [ 851 RowPartition.from_row_splits([0, 2]), 852 RowPartition.from_row_splits([0, 3, 6]) 853 ] 854 self.assertRowPartitionEq(actual[0], expected[0]) 855 self.assertRowPartitionEq(actual[1], expected[1]) 856 857 @parameterized.parameters([ 858 dict( 859 origin=b'x', 860 expected_lengths=[2, (1, 2)], 861 expected=[[b'x'], [b'x', b'x']]), 862 dict( 863 origin=b'x', 864 expected_lengths=[1, 1, 1], 865 expected_num_row_partitions=2, 866 expected=[[[b'x']]]), 867 dict( 868 origin=[b'a', b'b', b'c'], 869 expected_lengths=[3], 870 expected=[b'a', b'b', b'c']), 871 dict( 872 origin=[b'a', b'b', b'c'], 873 expected_lengths=[1, 1, 3], 874 expected_num_row_partitions=2, 875 expected=[[[b'a', b'b', b'c']]]), 876 dict( 877 origin=[[b'a', b'b', b'c'], [b'd', b'e', b'f']], 878 expected_lengths=[1, 2, 3], 879 expected_num_row_partitions=2, 880 expected=[[[b'a', b'b', b'c'], [b'd', b'e', b'f']]]), 881 ]) 882 def testBroadcastTensorTo(self, 883 origin, 884 expected_lengths, 885 expected, 886 expected_num_row_partitions=None): 887 origin = constant_op.constant(origin) 888 expected_shape = DynamicRaggedShape.from_lengths(expected_lengths) 889 if expected_num_row_partitions is not None: 890 expected_shape = expected_shape._with_num_row_partitions( 891 expected_num_row_partitions) 892 expected = ragged_factory_ops.constant_value(expected) 893 actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape) 894 self.assertAllEqual(actual, expected) 895 896 def testBroadcastFlatValues(self): 897 origin_lengths = [3, (1, 2, 1), 2, 2] 898 dest_lengths = [1, 1, 3, (1, 2, 1), 2, 2] 899 origin_values = constant_op.constant([ 900 b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', b'l', 901 b'm', b'n', b'o', b'p' 902 ]) 903 origin_shape = DynamicRaggedShape.from_lengths( 904 origin_lengths)._with_num_row_partitions(3) 905 dest_shape = DynamicRaggedShape.from_lengths( 906 dest_lengths)._with_num_row_partitions(5) 907 908 broadcaster = dynamic_ragged_shape._get_broadcaster(origin_shape, 909 dest_shape) 910 911 actual = broadcaster.broadcast_flat_values(origin_values) 912 913 self.assertAllEqual(origin_values, actual) 914 915 @parameterized.parameters([ 916 dict( 917 origin_lengths=[3], 918 origin_values=[b'a', b'b', b'c'], 919 expected_lengths=[2], 920 expected_values=[[b'a', b'b', b'c'], [b'a', b'b', b'c']]), 921 dict( 922 origin_lengths=[3, (3, 2, 4)], 923 origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 89], 924 expected_lengths=[3, (3, 2, 4)], 925 expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 89]), 926 dict( 927 origin_lengths=[3, (3, 2, 4)], 928 origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 89], 929 expected_lengths=[1, 3, (3, 2, 4)], 930 expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 89]), 931 dict( 932 origin_lengths=[3, (3, 2, 4)], 933 origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 89], 934 expected_lengths=[1, 1, 3, (3, 2, 4)], 935 expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 89]), 936 # Broadcast [1, 2, (1, 2)] to [2, 2, (1, 2, 1, 2)] 937 dict( 938 origin_lengths=[1, 2, (1, 2)], 939 origin_values=[2, 3, 5], 940 expected_lengths=[2, 2, (1, 2, 1, 2)], 941 expected_values=[2, 3, 5, 2, 3, 5]), 942 # Broadcast [2, 1, (1, 2)] to [2, 2, (1, 1, 2, 2)] (NEW) 943 dict( 944 origin_lengths=[2, 1, (1, 2)], 945 origin_values=[2, 3, 5], 946 expected_lengths=[2, 2, (1, 1, 2, 2)], 947 expected_values=[2, 2, 3, 5, 3, 5]), 948 dict( 949 origin_lengths=[2, 1, 1], 950 origin_values=[2, 3], # [[[2]], [[3]]] 951 expected_lengths=[2, 1, (3, 3)], 952 expected_values=[2, 2, 2, 3, 3, 3]), 953 dict( 954 origin_lengths=[3], 955 origin_values=[b'a', b'b', b'c'], 956 expected_lengths=[4, 2, 3], 957 expected_values=[ 958 b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', 959 b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', 960 b'b', b'c' 961 ]), 962 dict( 963 origin_lengths=[2, 3], 964 origin_values=[b'a', b'b', b'c', b'a', b'b', b'c'], 965 expected_lengths=[4, 2, 3], 966 expected_values=[ 967 b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', 968 b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', b'b', b'c', b'a', 969 b'b', b'c' 970 ]), 971 dict( 972 origin_lengths=[3, (1, 2, 1), 2, 2], 973 origin_values=[ 974 b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', 975 b'l', b'm', b'n', b'o', b'p' 976 ], 977 expected_lengths=[1, 1, 3, (1, 2, 1), 2, 2], 978 expected_values=[ 979 b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', b'k', 980 b'l', b'm', b'n', b'o', b'p' 981 ]), 982 dict( 983 origin_lengths=[3, (1, 2, 1), 2, 2], 984 origin_values=[7, 4, 5, 6, 1, 2, 3, 7, 7, 4, 5, 6, 1, 2, 3, 7], 985 expected_lengths=[1, 1, 3, (1, 2, 1), 2, 2], 986 expected_values=[7, 4, 5, 6, 1, 2, 3, 7, 7, 4, 5, 6, 1, 2, 3, 7], 987 ), 988 ]) 989 def testBroadcastRaggedTo(self, origin_lengths, origin_values, 990 expected_lengths, expected_values): 991 origin = _to_ragged_tensor_from_lengths(origin_values, origin_lengths) 992 expected = _to_ragged_tensor_from_lengths(expected_values, expected_lengths) 993 expected_shape = DynamicRaggedShape.from_tensor(expected) 994 actual = dynamic_ragged_shape.broadcast_to(origin, expected_shape) 995 self.assertAllEqual(actual, expected) 996 997 def testDynamicRaggedShapeFromTensor2(self): 998 raw_rt = [[[[7, 4], [5, 6]], [[1, 2], [3, 7]]], [[[7, 4], [5, 6]]], 999 [[[1, 2], [3, 7]]]] 1000 raw_rt = ragged_factory_ops.constant_value(raw_rt) 1001 actual_shape = DynamicRaggedShape.from_tensor(raw_rt) 1002 expected_shape = DynamicRaggedShape.from_lengths( 1003 [3, (2, 1, 1), 2, 2])._with_num_row_partitions(3) 1004 self.assertShapeEq(actual_shape, expected_shape) 1005 1006 # pylint: disable=g-long-lambda 1007 @parameterized.parameters([ 1008 # A row partition as opposed to a list of row partitions. 1009 dict( 1010 row_partitions=lambda: RowPartition.from_row_splits([0, 2, 3]), 1011 inner_shape=lambda: [4], 1012 error_type=TypeError, 1013 error_regex='row_partitions should be'), 1014 # A list of lists of integers for row_partitions. 1015 dict( 1016 row_partitions=lambda: [[0, 2, 3]], 1017 inner_shape=lambda: [4], 1018 error_type=TypeError, 1019 error_regex='row_partitions contains'), 1020 # nvals and nrows don't match (3 != 6) statically 1021 dict( 1022 row_partitions=lambda: [ # pylint: disable=g-long-lambda 1023 RowPartition.from_value_rowids([0, 2, 4], nrows=5), 1024 RowPartition.from_value_rowids([0, 2, 5], nrows=6) 1025 ], 1026 inner_shape=lambda: [3], 1027 validate=False, 1028 error_type=ValueError, 1029 error_regex='RowPartitions in DynamicRaggedShape do not'), 1030 # nvals and inner_shape[0] don't match (3 != 6) statically 1031 dict( 1032 row_partitions=lambda: [ 1033 RowPartition.from_value_rowids([0, 2, 4], nrows=5), 1034 ], 1035 inner_shape=lambda: [6], 1036 validate=False, 1037 error_type=ValueError, 1038 error_regex='Last row partition does not match inner_shape.'), 1039 ]) 1040 def testConstructorRaisesStatic(self, 1041 row_partitions, 1042 inner_shape, 1043 error_type, 1044 error_regex, 1045 validate=False, 1046 dtype=None): 1047 row_partitions = row_partitions() 1048 inner_shape = inner_shape() 1049 with self.assertRaisesRegex(error_type, error_regex): 1050 DynamicRaggedShape( 1051 row_partitions, inner_shape, dtype=dtype, validate=validate) 1052 1053 def testConstructorStaticOK(self): 1054 row_partitions = [ 1055 RowPartition.from_value_rowids([0, 2, 4], nrows=5), 1056 RowPartition.from_value_rowids([0, 1, 2], nrows=3) 1057 ] 1058 inner_shape = [3] 1059 rts = DynamicRaggedShape(row_partitions, inner_shape, validate=True) 1060 static_inner_shape = tensor_util.constant_value(rts.inner_shape) 1061 static_valid_rowids0 = tensor_util.constant_value( 1062 rts.row_partitions[0].value_rowids()) 1063 static_valid_rowids1 = tensor_util.constant_value( 1064 rts.row_partitions[1].value_rowids()) 1065 self.assertAllEqual(static_inner_shape, [3]) 1066 self.assertAllEqual(static_valid_rowids0, [0, 2, 4]) 1067 self.assertAllEqual(static_valid_rowids1, [0, 1, 2]) 1068 1069 def testConstructorWithStaticInnerShape(self): 1070 row_partitions = [ 1071 RowPartition.from_value_rowids([0, 2, 4], nrows=5), 1072 RowPartition.from_value_rowids([0, 1, 2], nrows=3) 1073 ] 1074 inner_shape = [3] 1075 rts = DynamicRaggedShape(row_partitions, inner_shape, validate=True, 1076 static_inner_shape=[3]) 1077 static_inner_shape = tensor_util.constant_value(rts.inner_shape) 1078 static_valid_rowids0 = tensor_util.constant_value( 1079 rts.row_partitions[0].value_rowids()) 1080 static_valid_rowids1 = tensor_util.constant_value( 1081 rts.row_partitions[1].value_rowids()) 1082 self.assertAllEqual(static_inner_shape, [3]) 1083 self.assertAllEqual(static_valid_rowids0, [0, 2, 4]) 1084 self.assertAllEqual(static_valid_rowids1, [0, 1, 2]) 1085 1086 def testZeros(self): 1087 shape_x = DynamicRaggedShape.from_lengths([3, (1, 3, 2), 4]) 1088 foo = ragged_array_ops.zeros(shape_x) 1089 self.assertShapeEq(shape_x, DynamicRaggedShape.from_tensor(foo)) 1090 self.assertAllEqual(array_ops.zeros([6, 4]), foo.flat_values) 1091 1092 def testOnes(self): 1093 shape_x = DynamicRaggedShape.from_lengths([3, (1, 3, 2), 4]) 1094 foo = ragged_array_ops.ones(shape_x) 1095 self.assertShapeEq(shape_x, DynamicRaggedShape.from_tensor(foo)) 1096 self.assertAllEqual(array_ops.ones([6, 4]), foo.flat_values) 1097 1098 def testReshapeTensor(self): 1099 foo = array_ops.zeros([3, 2, 4]) 1100 shape_b = DynamicRaggedShape.from_lengths([3, (3, 2, 1), 4]) 1101 result = ragged_array_ops.ragged_reshape(foo, shape_b) 1102 self.assertShapeEq(shape_b, DynamicRaggedShape.from_tensor(result)) 1103 self.assertAllEqual(array_ops.zeros([6, 4]), result.flat_values) 1104 1105 def test_reshape_ragged_tensor(self): 1106 shape_x = DynamicRaggedShape.from_lengths([3, (1, 3, 2), 4]) 1107 foo = ragged_array_ops.zeros(shape_x) 1108 shape_b = DynamicRaggedShape.from_lengths([3, (3, 2, 1), 4]) 1109 result = ragged_array_ops.ragged_reshape(foo, shape_b) 1110 self.assertShapeEq(shape_b, DynamicRaggedShape.from_tensor(result)) 1111 self.assertAllEqual(array_ops.zeros([6, 4]), result.flat_values) 1112 1113 @parameterized.parameters([ 1114 dict( 1115 lengths_a=[3, (1, 4, 2)], 1116 lengths_b=[3, (1, 4, 2)], 1117 lengths_e=[3, (1, 4, 2)]), 1118 dict( 1119 lengths_a=[1, 2, (1, 4)], 1120 lengths_b=[3, 2, (1, 4, 1, 4, 1, 4)], 1121 lengths_e=[3, 2, (1, 4, 1, 4, 1, 4)]), 1122 dict( 1123 lengths_a=[1, 1], 1124 num_row_partitions_a=1, 1125 lengths_b=[3, 5], 1126 num_row_partitions_b=1, 1127 lengths_e=[3, 5], 1128 num_row_partitions_e=1), 1129 dict(lengths_a=[1, 4, 5], lengths_b=[3, 1, 1], lengths_e=[3, 4, 5]), 1130 dict(lengths_a=[3], lengths_b=[4, 2, 1], lengths_e=[4, 2, 3]), 1131 dict(lengths_a=[2, 3], lengths_b=[4, 2, 1], lengths_e=[4, 2, 3]), 1132 # Outermost dimension-both partitioned 1133 # Also, neither has uniform_row_length 1134 dict( 1135 lengths_a=[2, (1, 3), 1], 1136 lengths_b=[2, (1, 3), (3, 4, 5, 6)], 1137 lengths_e=[2, (1, 3), (3, 4, 5, 6)]), 1138 # Outermost dimension-Only one is partitioned 1139 # Also, partitioned dimension doesn't have uniform_row_length 1140 dict( 1141 lengths_a=[2, 1, 5], 1142 lengths_b=[2, (1, 3), 5], 1143 num_row_partitions_b=2, 1144 lengths_e=[2, (1, 3), 5], 1145 num_row_partitions_e=2), 1146 1147 # Cover [5, R], [1, 5, R] 1148 dict( 1149 lengths_a=[5, (1, 2, 0, 3, 1)], 1150 lengths_b=[1, 5, (1, 2, 0, 3, 1)], 1151 lengths_e=[1, 5, (1, 2, 0, 3, 1)]), 1152 # When two uniform row lengths are equal 1153 dict( 1154 lengths_a=[1, 5], 1155 num_row_partitions_a=1, 1156 lengths_b=[3, 5], 1157 num_row_partitions_b=1, 1158 lengths_e=[3, 5], 1159 num_row_partitions_e=1), 1160 # Dense + Partitioned dimension has uniform_row_length 1161 # [1, 3, [5, 1, 6]] and DENSE [2, 1, 1] -> [2, 3, [5, 1, 6, 5, 1, 6]] 1162 dict( 1163 lengths_a=[1, 3, (5, 1, 6)], 1164 lengths_b=[2, 1, 1], 1165 lengths_e=[2, 3, (5, 1, 6, 5, 1, 6)]), 1166 # Both partitioned; one has uniform_row_length 1167 # (uniform_row_length [2,1,1]) and [2,[1,3],[3,4,5,6]] 1168 dict( 1169 lengths_a=[2, 1, 1], 1170 num_row_partitions_a=2, 1171 lengths_b=[2, (1, 3), (3, 4, 5, 6)], 1172 lengths_e=[2, (1, 3), (3, 4, 5, 6)]), 1173 # When broadcasting uniform_row_length to uniform_row_length. 1174 # Also, both have uniform_row_length 1175 dict( 1176 lengths_a=[3, 1, 5], 1177 num_row_partitions_a=2, 1178 lengths_b=[3, 4, 5], 1179 num_row_partitions_b=2, 1180 lengths_e=[3, 4, 5], 1181 num_row_partitions_e=2), 1182 # When broadcasting above a U_R_L 1183 # [2,1, 5] and [2, [1,3], 5] -> [2, [1,3], 5] 1184 dict( 1185 lengths_a=[2, 1, 5], 1186 num_row_partitions_a=2, 1187 lengths_b=[2, (1, 3), 5], 1188 num_row_partitions_b=2, 1189 lengths_e=[2, (1, 3), 5], 1190 num_row_partitions_e=2), 1191 # What if the larger-dimensional shape has uniform_row_length on the 1192 # matching dim, but has larger dimensions above 1193 # ([3,1,5],[15]) vs ([2,1],[2])) 1194 dict( 1195 lengths_a=[3, 1, 5], 1196 num_row_partitions_a=2, 1197 lengths_b=[2, 1], 1198 num_row_partitions_b=1, 1199 lengths_e=[3, 2, 5], 1200 num_row_partitions_e=2), 1201 # Inner non-ragged dimensions 1202 # Can delegate to dense broadcast operations. 1203 # Implementation detail: not testable. 1204 # ([2, [1,2]],[3,2,1]) and ([2,1],[2,1,3]) 1205 dict( 1206 lengths_a=[2, (1, 2), 2, 1], 1207 lengths_b=[2, 1, 1, 3], 1208 num_row_partitions_b=1, 1209 lengths_e=[2, (1, 2), 2, 3], 1210 ), 1211 ]) 1212 def testBroadcastDynamicShapeExtended(self, 1213 lengths_a, 1214 lengths_b, 1215 lengths_e, 1216 num_row_partitions_a=None, 1217 num_row_partitions_b=None, 1218 num_row_partitions_e=None): 1219 # This test is predicated on the fact that broadcast_to is correct. 1220 # Thus, it tests: 1221 # Whether the shape generated is correct. 1222 # Whether broadcasting is the same as broadcast_to. 1223 # Instead of specifying values, it just uses primes. 1224 shape_a = DynamicRaggedShape.from_lengths(lengths_a) 1225 if num_row_partitions_a is not None: 1226 shape_a = shape_a._with_num_row_partitions(num_row_partitions_a) 1227 shape_b = DynamicRaggedShape.from_lengths(lengths_b) 1228 if num_row_partitions_b is not None: 1229 shape_b = shape_b._with_num_row_partitions(num_row_partitions_b) 1230 shape_e = DynamicRaggedShape.from_lengths(lengths_e) 1231 if num_row_partitions_e is not None: 1232 shape_e = shape_e._with_num_row_partitions(num_row_partitions_e) 1233 1234 [actual, bc_a, bc_b 1235 ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_a, shape_b) 1236 [actual_rev, bc_b_rev, bc_a_rev 1237 ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_b, shape_a) 1238 self.assertShapeEq(actual, shape_e) 1239 self.assertShapeEq(actual_rev, shape_e) 1240 1241 rt_a = ragged_array_ops.ragged_reshape( 1242 _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a) 1243 bc_a_actual = bc_a.broadcast(rt_a) 1244 bc_a_actual_rev = bc_a_rev.broadcast(rt_a) 1245 bc_a_expected = dynamic_ragged_shape.broadcast_to(rt_a, shape_e) 1246 self.assertAllEqual(bc_a_expected, bc_a_actual) 1247 self.assertAllEqual(bc_a_expected, bc_a_actual_rev) 1248 1249 rt_b = ragged_array_ops.ragged_reshape( 1250 _lowest_primes(_num_elements_of_lengths(lengths_b)), shape_b) 1251 bc_b_expected = dynamic_ragged_shape.broadcast_to(rt_b, shape_e) 1252 bc_b_actual = bc_b.broadcast(rt_b) 1253 bc_b_actual_rev = bc_b_rev.broadcast(rt_b) 1254 self.assertAllEqual(bc_b_expected, bc_b_actual) 1255 self.assertAllEqual(bc_b_expected, bc_b_actual_rev) 1256 1257 @parameterized.parameters([ 1258 dict( 1259 lengths=[3, (1, 4, 2)], 1260 dense_rank=1, 1261 lengths_e=[3, (1, 4, 2)], 1262 ), 1263 dict( 1264 lengths=[3, (1, 4, 2), 5], 1265 dense_rank=2, 1266 lengths_e=[3, (1, 4, 2), 5], 1267 ), 1268 dict( 1269 lengths=[3], 1270 dense_rank=1, 1271 lengths_e=[3], 1272 ), 1273 ]) 1274 def testWithDenseRank(self, lengths, dense_rank, lengths_e): 1275 # Makes little sense with from_lengths/_with_num_row_partitions. 1276 original = DynamicRaggedShape.from_lengths(lengths) 1277 actual = original._with_inner_rank(dense_rank) 1278 self.assertAllEqual(actual.inner_rank, dense_rank) 1279 self.assertAllEqual(actual.static_lengths(), lengths_e) 1280 1281 @parameterized.parameters([ 1282 dict( 1283 rps=[3, [1, 4, 2]], 1284 lengths_e=[3, (1, 4, 2)], 1285 num_row_partitions_e=1, 1286 ), 1287 dict( 1288 rps=[3, [1, 4, 2], 2], 1289 lengths_e=[3, (1, 4, 2), 2], 1290 num_row_partitions_e=2, 1291 ), 1292 ]) 1293 def testFromRowPartitions(self, rps, lengths_e, num_row_partitions_e): 1294 rps = _to_row_partitions_from_lengths(rps) 1295 actual = DynamicRaggedShape.from_row_partitions(rps) 1296 expected = DynamicRaggedShape.from_lengths( 1297 lengths_e)._with_num_row_partitions(num_row_partitions_e) 1298 self.assertShapeEq(expected, actual) 1299 1300 def testFromRowPartitionsError(self): 1301 with self.assertRaisesRegex(ValueError, 'row_partitions cannot be empty'): 1302 DynamicRaggedShape.from_row_partitions([]) 1303 1304 @parameterized.parameters([ 1305 #========================================================================= 1306 # dimension[axis] is uniform inner; and row_lengths is a scalar 1307 #========================================================================= 1308 # shape: [BROADCAST(UNIFORM), UNIFORM, UNIFORM] 1309 dict(original_lengths=[1, 4, 5], 1310 broadcast_lengths=[3, 4, 5]), 1311 # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)] 1312 dict(original_lengths=[3, 4, 1], 1313 broadcast_lengths=[3, 4, 5]), 1314 # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)] 1315 dict(original_lengths=[3, (3, 2, 8), 1], 1316 broadcast_lengths=[3, (3, 2, 8), 5]), 1317 # shape: [UNIFORM, RAGGED, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM)] 1318 dict(original_lengths=[2, (2, 1), (3, 2, 8), 3, 4, 1], 1319 broadcast_lengths=[2, (2, 1), (3, 2, 8), 3, 4, 5]), 1320 1321 #========================================================================= 1322 # dimension[axis] is uniform inner; and row_lengths is a vector 1323 #========================================================================= 1324 # shape: [UNIFORM, BROADCAST(UNIFORM)] 1325 dict(original_lengths=[3, 1], 1326 broadcast_lengths=[3, (2, 0, 1)]), 1327 # shape: [UNIFORM, BROADCAST(UNIFORM), UNIFORM] 1328 dict(original_lengths=[3, 1, 5], 1329 broadcast_lengths=[3, (2, 0, 1), 5]), 1330 1331 # shape: [UNIFORM, UNIFORM, BROADCAST(UNIFORM)] 1332 dict(original_lengths=[4, 3, 1], 1333 broadcast_lengths=[4, 3, (2, 0, 1, 3, 8, 2, 3, 4, 1, 8, 7, 0)]), 1334 1335 # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM)] 1336 dict(original_lengths=[2, (2, 1), 1], 1337 broadcast_lengths=[2, (2, 1), (2, 5, 3)]), 1338 1339 # shape: [UNIFORM, RAGGED, UNIFORM, UNIFORM, BROADCAST(UNIFORM), UNIFORM] 1340 dict(original_lengths=[2, (2, 1), 3, 2, 1, 8], 1341 broadcast_lengths=[2, (2, 1), 3, 2, tuple(range(18)), 8]), 1342 1343 #========================================================================= 1344 # dimension[axis] is uniform partitioned; and row_lengths is a scalar 1345 #========================================================================= 1346 # shape: [BROADCAST(UNIFORM), RAGGED] 1347 dict(original_lengths=[1, (5,)], 1348 broadcast_lengths=[3, (5, 5, 5)]), 1349 1350 # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED] 1351 dict(original_lengths=[1, 3, (3, 0, 2)], 1352 broadcast_lengths=[2, 3, (3, 0, 2, 3, 0, 2)]), 1353 1354 # shape: [BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM, UNIFORM] 1355 dict(original_lengths=[1, (3,), (3, 5, 2), 9, 4, 5], 1356 broadcast_lengths=[3, (3, 3, 3), (3, 5, 2, 3, 5, 2, 3, 5, 2), 1357 9, 4, 5]), 1358 1359 # shape: [BROADCAST(UNIFORM), UNIFORM, RAGGED, UNIFORM] 1360 dict(original_lengths=[1, 2, (2, 1), (3, 5, 2), 2], 1361 broadcast_lengths=[2, 2, (2, 1, 2, 1), (3, 5, 2, 3, 5, 2), 2]), 1362 1363 # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM] 1364 # This is wrong. should broadcast to [3, 2, (4, 4, 0, 0, 2, 2), 5] 1365 # dict(original_lengths=[3, 1, [4, 0, 2], 5], 1366 # broadcast_lengths=[3, 2, [4, 0, 2, 4, 0, 2], 5]), 1367 dict(original_lengths=[3, 1, (4, 0, 2), 5], 1368 broadcast_lengths=[3, 2, (4, 4, 0, 0, 2, 2), 5]), 1369 1370 # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED] 1371 dict(original_lengths=[2, 3, (1, 2, 3, 4, 5, 6)], 1372 broadcast_lengths=[2, 3, (1, 2, 3, 4, 5, 6)]), 1373 1374 #========================================================================= 1375 # dimension[axis] is uniform partitioned; and row_lengths is a vector 1376 #========================================================================= 1377 # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, UNIFORM] 1378 dict(original_lengths=[ 1379 3, # axis=0 1380 1, # axis=1 (broadcast) 1381 (3, 1, 2), # axis=2 1382 5], # axis=3 1383 broadcast_lengths=[ 1384 3, # axis=0 1385 (4, 1, 2), # axis=1 (broadcast) 1386 (3, 3, 3, 3, 1, 2, 2), # axis=2 1387 5]), # axis=3 1388 1389 # shape: [UNIFORM, BROADCAST(UNIFORM), RAGGED, RAGGED] 1390 dict(original_lengths=[ 1391 3, # axis=0 1392 1, # axis=1 (broadcast) 1393 (3, 1, 2), # axis=2 1394 (3, 1, 4, 1, 5, 9)], # axis=3 1395 broadcast_lengths=[ 1396 3, # axis=0 1397 (2, 0, 3), # axis=1 (broadcast) 1398 (3, 3, 2, 2, 2), # axis=2 1399 (3, 1, 4, 3, 1, 4, 5, 9, 5, 9, 5, 9)]), # axis=3 1400 1401 # shape: [UNIFORM, RAGGED, BROADCAST(UNIFORM), RAGGED, RAGGED, UNIFORM] 1402 dict(original_lengths=[ 1403 3, # axis=0 1404 (2, 0, 1), # axis=1 1405 1, # axis=2 (broadcast) 1406 (3, 2, 1), # axis=3 1407 (1, 0, 1, 0, 2, 3), # axis=4 1408 5], # axis=5 1409 broadcast_lengths=[ 1410 3, # axis=0 1411 (2, 0, 1), # axis=2 1412 (4, 1, 2), # axis=2 (broadcast) 1413 (3, 3, 3, 3, 2, 1, 1), # axis=3 1414 (1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, # axis=4 1415 2, 3, 3), 1416 5]), # axis=5 1417 dict(original_lengths=[1, 1, 2, (2, 1)], 1418 broadcast_lengths=[2, 1, 2, (2, 1, 2, 1)]), 1419 dict(original_lengths=[2, 1, 2, (2, 1, 2, 1)], 1420 broadcast_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]), 1421 dict(original_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)], 1422 broadcast_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]), 1423 dict(original_lengths=[2, (2, 1), 2, 1], 1424 broadcast_lengths=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]), 1425 ]) # pyformat: disable 1426 def testBroadcastDimension(self, original_lengths, broadcast_lengths): 1427 """Tests broadcast_to on a single dimension.""" 1428 original_rt = _to_prime_tensor_from_lengths(original_lengths) 1429 bcast_shape = DynamicRaggedShape.from_lengths(broadcast_lengths) 1430 result_rt = dynamic_ragged_shape.broadcast_to(original_rt, bcast_shape) 1431 result_shape = DynamicRaggedShape.from_tensor(result_rt) 1432 1433 self.assertShapeEq(bcast_shape, result_shape) 1434 1435 def testAsRowPartitions(self): 1436 my_shape = DynamicRaggedShape.from_lengths([3, (2, 0, 1), 5]) 1437 rps = my_shape._as_row_partitions() 1438 self.assertLen(rps, 2) 1439 1440 def testAsRowPartitionsRaises(self): 1441 my_shape = DynamicRaggedShape.from_lengths([]) 1442 with self.assertRaisesRegex(ValueError, 1443 'rank must be >= 1 for _as_row_partitions'): 1444 my_shape._as_row_partitions() 1445 1446 def testToPrimeTensorFromDimSizes(self): 1447 """Tests the test utility.""" 1448 original_lengths = [3, (3, 2, 8), 1] 1449 original_rt = _to_prime_tensor_from_lengths(original_lengths) 1450 expected_rt = _to_ragged_tensor_from_lengths( 1451 [[2], [3], [5], [7], [11], [13], [17], [19], [23], [29], [31], [37], 1452 [41]], [3, (3, 2, 8)]) 1453 1454 self.assertAllEqual(expected_rt, original_rt) 1455 1456 @parameterized.parameters([ 1457 # Broadcast scalar 1458 dict(x_dims=[], y_dims=[], expected_dims=[]), 1459 dict(x_dims=[], y_dims=[2], expected_dims=[2]), 1460 dict(x_dims=[], y_dims=[2, 3], expected_dims=[2, 3]), 1461 dict( 1462 x_dims=[], 1463 y_dims=[2, (2, 3), (5, 7, 2, 0, 9)], 1464 expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]), 1465 # Broadcast vector 1466 dict(x_dims=[3], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]), 1467 dict(x_dims=[1], y_dims=[4, 2, 3], expected_dims=[4, 2, 3]), 1468 dict(x_dims=[3], y_dims=[4, 2, 1], expected_dims=[4, 2, 3]), 1469 dict( 1470 x_dims=[3], y_dims=[3, (2, 3, 1), 1], expected_dims=[3, (2, 3, 1), 1471 3]), 1472 dict(x_dims=[1], y_dims=[3, (2, 1, 3)], expected_dims=[3, (2, 1, 3)]), 1473 dict( 1474 x_dims=[1], y_dims=[3, (2, 1, 3), 8], expected_dims=[3, (2, 1, 3), 1475 8]), 1476 dict( 1477 x_dims=[1], 1478 y_dims=[2, (2, 3), (5, 7, 2, 0, 9)], 1479 expected_dims=[2, (2, 3), (5, 7, 2, 0, 9)]), 1480 # Mixed broadcasting 1481 dict( 1482 x_dims=[ 1483 1, # axis=0 1484 3, # axis=1 1485 (3, 0, 2), # axis=2 1486 1, # axis=3 1487 2, # axis=4 1488 ], 1489 y_dims=[ 1490 2, # axis=0 1491 1, # axis=1 1492 1, # axis=2 1493 (7, 2), # axis=3 1494 1, # axis=4 1495 ], 1496 expected_dims=[ 1497 2, # axis=0 1498 3, # axis=1 1499 (3, 0, 2, 3, 0, 2), # axis=2 1500 (7, 7, 7, 7, 7, 2, 2, 2, 2, 2), # axis=3 1501 2, # axis=4 1502 ]), 1503 dict( 1504 x_dims=[2, (2, 1), 2, 1], 1505 y_dims=[1, 1, 2, (2, 1)], 1506 expected_dims=[2, (2, 1), 2, (2, 1, 2, 1, 2, 1)]), 1507 ]) 1508 def testBroadcastDynamicShape(self, x_dims, y_dims, expected_dims): 1509 shape_a = DynamicRaggedShape.from_lengths(x_dims) 1510 shape_b = DynamicRaggedShape.from_lengths(y_dims) 1511 shape_e = DynamicRaggedShape.from_lengths(expected_dims) 1512 [actual, bc_a, bc_b 1513 ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_a, shape_b) 1514 [actual_rev, bc_b_rev, bc_a_rev 1515 ] = dynamic_ragged_shape.broadcast_dynamic_shape_extended(shape_b, shape_a) 1516 self.assertShapeEq(actual, shape_e) 1517 self.assertShapeEq(actual_rev, shape_e) 1518 1519 rt_a = _to_prime_tensor_from_lengths(x_dims) 1520 bc_a_actual = bc_a.broadcast(rt_a) 1521 bc_a_actual_rev = bc_a_rev.broadcast(rt_a) 1522 bc_a_expected = dynamic_ragged_shape.broadcast_to(rt_a, shape_e) 1523 self.assertAllEqual(bc_a_expected, bc_a_actual) 1524 self.assertAllEqual(bc_a_expected, bc_a_actual_rev) 1525 1526 rt_b = _to_prime_tensor_from_lengths(y_dims) 1527 bc_b_expected = dynamic_ragged_shape.broadcast_to(rt_b, shape_e) 1528 bc_b_actual = bc_b.broadcast(rt_b) 1529 bc_b_actual_rev = bc_b_rev.broadcast(rt_b) 1530 self.assertAllEqual(bc_b_expected, bc_b_actual) 1531 self.assertAllEqual(bc_b_expected, bc_b_actual_rev) 1532 1533 # This just wraps broadcast_dynamic_shape_extended, so nothing 1534 # deeper is required. 1535 result1 = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b) 1536 self.assertShapeEq(shape_e, result1) 1537 1538 # Again, just a wrapper. 1539 result2 = ragged_array_ops.broadcast_dynamic_shape(shape_a, shape_b) 1540 self.assertShapeEq(shape_e, result2) 1541 1542 def testBroadcastDynamicShapeFirstLayer(self): 1543 a_0 = constant_op.constant(1, dtypes.int64) 1544 b_0 = constant_op.constant(3, dtypes.int64) 1545 [a_layer, b_layer 1546 ] = dynamic_ragged_shape._broadcast_dynamic_shape_first_layer(a_0, b_0) 1547 expected_a_layer = _LayerBroadcaster.from_gather_index([0, 0, 0]) 1548 expected_b_layer = _LayerBroadcaster.from_gather_index([0, 1, 2]) 1549 self.assertLayerBroadcasterEq(expected_a_layer, a_layer) 1550 self.assertLayerBroadcasterEq(expected_b_layer, b_layer) 1551 1552 def testBroadcastDynamicShapeNextLayer(self): 1553 a_1 = RowPartition.from_uniform_row_length( 1554 1, nvals=1, nrows=1, dtype_hint=dtypes.int64) 1555 b_1 = RowPartition.from_row_lengths([2, 1, 3], dtype_hint=dtypes.int64) 1556 ac_0 = _LayerBroadcaster.from_gather_index( 1557 constant_op.constant([0, 0, 0], dtype=dtypes.int64)) 1558 bc_0 = _LayerBroadcaster.from_gather_index( 1559 constant_op.constant([0, 1, 2], dtype=dtypes.int64)) 1560 dynamic_ragged_shape._broadcast_dynamic_shape_next_layer_half_ragged( 1561 ac_0, bc_0, a_1, b_1) 1562 1563 def testBroadcastDynamicShapeRaisesLeft(self): 1564 shape = DynamicRaggedShape.from_tensor(constant_op.constant([1, 2, 3])) 1565 with self.assertRaisesRegex(TypeError, 'shape_x must be'): 1566 dynamic_ragged_shape.broadcast_dynamic_shape(1, shape) 1567 1568 def testBroadcastDynamicShapeRaisesRight(self): 1569 shape = DynamicRaggedShape.from_tensor(constant_op.constant([1, 2, 3])) 1570 with self.assertRaisesRegex(TypeError, 'shape_y must be'): 1571 dynamic_ragged_shape.broadcast_dynamic_shape(shape, 1) 1572 1573 def testBroadcastToRaises(self): 1574 rt = constant_op.constant([1, 2, 3]) 1575 with self.assertRaisesRegex(TypeError, 'shape must be'): 1576 dynamic_ragged_shape.broadcast_to(rt, 1) 1577 1578 @parameterized.parameters([ 1579 dict( 1580 x=[[10], [20], [30]], # shape=[3, 1] 1581 lengths=[3, 2], 1582 expected=[[10, 10], [20, 20], [30, 30]]), 1583 dict( 1584 x=[[10], [20], [30]], # shape=[3, 1] 1585 lengths=[3, (3, 0, 2)], 1586 expected=ragged_factory_ops.constant_value( 1587 [[10, 10, 10], [], [30, 30]], dtype=np.int32)), 1588 dict( 1589 x=[[[1, 2, 3]], [[4, 5, 6]]], # shape = [2, 1, 3] 1590 lengths=[2, (2, 3), 3], 1591 expected=ragged_factory_ops.constant_value( 1592 [[[1, 2, 3], [1, 2, 3]], [[4, 5, 6], [4, 5, 6], [4, 5, 6]]], 1593 dtype=np.int32, 1594 ragged_rank=1)), 1595 dict( 1596 x=[[[1]], [[2]]], # shape = [2, 1, 1] 1597 lengths=[2, (2, 3), (0, 2, 1, 2, 0)], 1598 expected=ragged_factory_ops.constant_value( 1599 [[[], [1, 1]], [[2], [2, 2], []]], dtype=np.int32, 1600 ragged_rank=2)), 1601 dict( 1602 x=10, 1603 lengths=[3, (3, 0, 2)], 1604 expected=ragged_factory_ops.constant_value([[10, 10, 10], [], 1605 [10, 10]])), 1606 dict( 1607 x=ragged_factory_ops.constant_value([[[1], [2]], [[3]]], 1608 ragged_rank=1), 1609 lengths=[2, (2, 1), 2], 1610 expected=ragged_factory_ops.constant_value( 1611 [[[1, 1], [2, 2]], [[3, 3]]], ragged_rank=1)), 1612 ]) 1613 def testRaggedBroadcastTo(self, x, lengths, expected): 1614 shape = DynamicRaggedShape.from_lengths(lengths) 1615 result = dynamic_ragged_shape.broadcast_to(x, shape) 1616 self.assertEqual( 1617 getattr(result, 'num_row_partitions', 0), 1618 getattr(expected, 'num_row_partitions', 0)) 1619 self.assertAllEqual(result, expected) 1620 1621 # broadcast_to just calls dynamic_ragged_shape.broadcast_to, so 1622 # this should be sufficient. 1623 result2 = ragged_array_ops.broadcast_to(x, shape) 1624 self.assertAllEqual(result2, expected) 1625 1626 @parameterized.parameters([ 1627 dict( 1628 doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]', 1629 x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]], 1630 dtype=np.int32), 1631 y=[[10], [20], [30]], 1632 expected=ragged_factory_ops.constant_value([[11, 12, 13], [], 1633 [34, 35]])), 1634 dict( 1635 doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]', 1636 x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]], 1637 dtype=np.int32), 1638 y=10, 1639 expected=ragged_factory_ops.constant_value([[11, 12, 13], [], 1640 [14, 15]])), 1641 dict( 1642 doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]', 1643 x=ragged_factory_ops.constant_value([[1, 2, 3]], dtype=np.int32), 1644 y=[[10], [20], [30]], 1645 expected=ragged_factory_ops.constant_value( 1646 [[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)), 1647 dict( 1648 doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; ' 1649 'bcast.shape=[2, (D1), (D2)]'), 1650 x=ragged_factory_ops.constant_value([[[1], [2], [3]], [[4]]], 1651 ragged_rank=1), 1652 y=ragged_factory_ops.constant_value([[10, 20, 30]]), 1653 expected=ragged_factory_ops.constant_value([[[11, 21, 1654 31], [12, 22, 32], 1655 [13, 23, 33]], 1656 [[14, 24, 34]]])), 1657 dict( 1658 doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; ' 1659 'bcast.shape=[2, (D1), 4]'), 1660 x=ragged_factory_ops.constant_value([[[10], [20]], [[30]]], 1661 ragged_rank=1), 1662 y=[[[1, 2, 3, 4]]], 1663 expected=ragged_factory_ops.constant_value( 1664 [[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]], 1665 ragged_rank=1)), 1666 dict( 1667 doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; ' 1668 'bcast.shape=[2, (D1), (2), (D2)'), 1669 x=ragged_factory_ops.constant_value( 1670 [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1), 1671 y=ragged_factory_ops.constant_value([[10, 20], [30]]), 1672 expected=ragged_factory_ops.constant_value([[[[11, 21], [32]], 1673 [[13, 23], [34]]], 1674 [[[15, 25], [36]]]])), 1675 ]) 1676 def testRaggedAddWithBroadcasting(self, x, y, expected, doc): 1677 expected_rrank = getattr(expected, 'num_row_partitions', 0) 1678 x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32) 1679 y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32) 1680 result = x + y 1681 result_rrank = getattr(result, 'num_row_partitions', 0) 1682 self.assertEqual(expected_rrank, result_rrank) 1683 if hasattr(expected, 'tolist'): 1684 expected = expected.tolist() 1685 self.assertAllEqual(result, expected) 1686 1687 @parameterized.parameters([ 1688 dict(lengths_a=[3, (1, 4, 2)], new_impl=True, op_max=10), # Actual ops: 5 1689 dict(lengths_a=[3, (1, 4, 2)], new_impl=False, op_max=300), 1690 ]) 1691 def testAddSelf(self, lengths_a, new_impl, op_max, num_row_partitions_a=None): 1692 if context.executing_eagerly(): 1693 return 1694 shape_a0 = DynamicRaggedShape.from_lengths( 1695 lengths_a, num_row_partitions=num_row_partitions_a) 1696 rt_a = ragged_array_ops.ragged_reshape( 1697 _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a0) 1698 rt_b = rt_a 1699 g = rt_a.flat_values.graph if ragged_tensor.is_ragged(rt_a) else rt_a.graph 1700 nodes_at_a = len(g.as_graph_def().node) 1701 if new_impl: 1702 dynamic_ragged_shape.ragged_binary_elementwise_op_impl( 1703 gen_math_ops.add_v2, rt_a, rt_b) 1704 nodes_at_b = len(g.as_graph_def().node) 1705 node_delta = nodes_at_b - nodes_at_a 1706 self.assertLessEqual(node_delta, op_max) 1707 else: 1708 if isinstance(rt_a, RaggedTensor): 1709 rt_a = rt_a.with_row_splits_dtype(dtypes.int32) 1710 rt_b = rt_a 1711 nodes_at_b = len(g.as_graph_def().node) 1712 rt_a + rt_b # pylint: disable=pointless-statement 1713 nodes_at_d = len(g.as_graph_def().node) 1714 node_delta = nodes_at_d - nodes_at_b 1715 self.assertLessEqual(node_delta, op_max) 1716 1717 def testAndSelfBool(self): 1718 if context.executing_eagerly(): 1719 return 1720 values = constant_op.constant([True, False, True, True, True]) 1721 rt_a = RaggedTensor.from_row_splits(values, [0, 3, 3, 5]) 1722 result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl( 1723 gen_math_ops.logical_and, rt_a, rt_a) 1724 1725 expected_values = values 1726 expected = RaggedTensor.from_row_splits(expected_values, [0, 3, 3, 5]) 1727 1728 self.assertAllEqual(result, expected) 1729 1730 def testEquals(self): 1731 if context.executing_eagerly(): 1732 return 1733 1734 rt_a = ragged_factory_ops.constant([[3, 1, 3], [3]]) 1735 b = constant_op.constant(3) 1736 rt_expected = ragged_factory_ops.constant([[True, False, True], [True]]) 1737 1738 result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl( 1739 math_ops.equal, rt_a, b) 1740 self.assertAllEqual(result, rt_expected) 1741 1742 def testEquals2(self): 1743 splits = constant_op.constant([0, 1]) 1744 a = RaggedTensor.from_row_splits([[1, 2]], splits) 1745 b = RaggedTensor.from_row_splits([[3, 4, 5]], splits) 1746 self.assertIs(a == b, False) 1747 1748 def testEquals3(self): 1749 a = RaggedTensor.from_row_splits([[1, 2]], [0, 1]) 1750 b = RaggedTensor.from_row_splits([[3, 4, 5]], [0, 1]) 1751 self.assertIs(a == b, False) 1752 1753 @parameterized.parameters([ 1754 dict( 1755 lengths_a=[3, (1, 4, 2)], lengths_b=[], new_impl=True, 1756 max_num_ops=5), # Actual ops: 1 1757 dict( 1758 lengths_a=[3, (1, 4, 2), 3, 2], 1759 lengths_b=[3, 2], 1760 new_impl=True, 1761 max_num_ops=5), # Actual ops: 1 1762 dict( 1763 lengths_a=[3, (1, 4, 2)], lengths_b=[], new_impl=False, 1764 max_num_ops=5), # Actual ops: 1 1765 dict( 1766 lengths_a=[3, (1, 4, 2), 3, 2], 1767 lengths_b=[3, 2], 1768 new_impl=False, 1769 max_num_ops=5), # Actual ops: 1 1770 ]) 1771 def testAdd(self, 1772 lengths_a, 1773 lengths_b, 1774 new_impl, 1775 max_num_ops, 1776 num_row_partitions_a=None, 1777 num_row_partitions_b=None): 1778 if context.executing_eagerly(): 1779 return 1780 1781 shape_a0 = DynamicRaggedShape.from_lengths( 1782 lengths_a, num_row_partitions=num_row_partitions_a) 1783 shape_b0 = DynamicRaggedShape.from_lengths( 1784 lengths_b, num_row_partitions=num_row_partitions_b) 1785 rt_a = ragged_array_ops.ragged_reshape( 1786 _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a0) 1787 rt_b = ragged_array_ops.ragged_reshape( 1788 _lowest_primes(_num_elements_of_lengths(lengths_b)), shape_b0) 1789 g = rt_a.flat_values.graph if ragged_tensor.is_ragged(rt_a) else rt_a.graph 1790 1791 nodes_at_a = len(g.as_graph_def().node) 1792 if new_impl: 1793 dynamic_ragged_shape.ragged_binary_elementwise_op_impl( 1794 gen_math_ops.add_v2, 1795 rt_a, 1796 rt_b) 1797 nodes_at_b = len(g.as_graph_def().node) 1798 num_nodes = nodes_at_b - nodes_at_a 1799 self.assertLessEqual(num_nodes, max_num_ops) 1800 else: 1801 if isinstance(rt_a, RaggedTensor): 1802 rt_a = rt_a.with_row_splits_dtype(dtypes.int32) 1803 if isinstance(rt_b, RaggedTensor): 1804 rt_b = rt_b.with_row_splits_dtype(dtypes.int32) 1805 nodes_at_b = len(g.as_graph_def().node) 1806 rt_a + rt_b # pylint: disable=pointless-statement 1807 nodes_at_d = len(g.as_graph_def().node) 1808 num_nodes = nodes_at_d - nodes_at_b 1809 1810 @parameterized.parameters([ 1811 dict( 1812 lengths_a=[3, (1, 4, 2)], lengths_b=[], 1813 shape_e=[3, None], new_impl=False), 1814 dict( 1815 lengths_a=[3, (1, 4, 2)], lengths_b=[], 1816 shape_e=[3, None], new_impl=True), 1817 dict( 1818 lengths_a=[5, (1, 4, 2, 1, 3), 3], 1819 lengths_b=[5, 1, 3], 1820 shape_e=[5, None, 3], new_impl=False), 1821 dict( 1822 lengths_a=[5, (1, 4, 2, 1, 3), 3], 1823 lengths_b=[5, 1, 3], 1824 shape_e=[5, None, 3], new_impl=True), 1825 dict( 1826 lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3], 1827 lengths_b=[3, 2, 1, 3], 1828 shape_e=[3, 2, None, 3], new_impl=False), 1829 dict( 1830 lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3], 1831 lengths_b=[3, 2, 1, 3], 1832 shape_e=[3, 2, None, 3], 1833 new_impl=True), 1834 dict( 1835 lengths_a=[3, (1, 4, 2)], lengths_b=[3, 1], 1836 shape_e=[3, None], new_impl=False), 1837 dict( 1838 lengths_a=[3, (1, 4, 2)], lengths_b=[3, 1], 1839 shape_e=[3, None], new_impl=True), 1840 1841 ]) 1842 def testAddShape(self, 1843 lengths_a, 1844 lengths_b, 1845 shape_e, 1846 new_impl=False, 1847 num_row_partitions_a=None, 1848 num_row_partitions_b=None): 1849 if context.executing_eagerly(): 1850 return 1851 shape_a = DynamicRaggedShape.from_lengths( 1852 lengths_a, num_row_partitions=num_row_partitions_a) 1853 shape_b = DynamicRaggedShape.from_lengths( 1854 lengths_b, num_row_partitions=num_row_partitions_b) 1855 rt_a = ragged_array_ops.ragged_reshape( 1856 _lowest_primes(_num_elements_of_lengths(lengths_a)), shape_a) 1857 rt_b = ragged_array_ops.ragged_reshape( 1858 _lowest_primes(_num_elements_of_lengths(lengths_b)), shape_b) 1859 if new_impl: 1860 result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl( 1861 math_ops.add, rt_a, rt_b) 1862 shape_e = tensor_shape.TensorShape(shape_e) 1863 self.assertEqual(shape_e.as_list(), result.shape.as_list()) 1864 else: 1865 if isinstance(rt_a, RaggedTensor): 1866 rt_a = rt_a.with_row_splits_dtype(dtypes.int32) 1867 if isinstance(rt_b, RaggedTensor): 1868 rt_b = rt_b.with_row_splits_dtype(dtypes.int32) 1869 result = rt_a + rt_b 1870 shape_e = tensor_shape.TensorShape(shape_e) 1871 self.assertEqual(shape_e.as_list(), result.shape.as_list()) 1872 1873 @parameterized.parameters([ 1874 dict( 1875 lengths_a=[3, (1, 4, 2)], lengths_b=[], 1876 shape_e=[3, (1, 4, 2)]), 1877 dict( 1878 lengths_a=[5], lengths_b=[1], 1879 shape_e=[5]), 1880 dict( 1881 lengths_a=[5, (1, 4, 2, 1, 3), 3], 1882 lengths_b=[5, 1, 3], 1883 shape_e=[5, None, 3]), 1884 dict( 1885 lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3], 1886 lengths_b=[3, 2, 1, 3], 1887 shape_e=[3, 2, None, 3]), 1888 dict(lengths_a=[3, (1, 4, 2)], lengths_b=[3, 1], shape_e=[3, None]), 1889 dict(lengths_a=[5, 1, 3], lengths_b=[2, 3], shape_e=[5, 2, 3]), 1890 dict(lengths_a=[5, 1, (3, 2, 4, 1, 3)], lengths_b=[2, 1], 1891 shape_e=[5, 2, None]), 1892 dict(lengths_a=[5, 4, 1, 3], lengths_b=[2, 1], shape_e=[5, 4, 2, 3]), 1893 ]) 1894 def testBroadcastDynamicShapeStatic(self, 1895 lengths_a, 1896 lengths_b, 1897 shape_e, 1898 num_row_partitions_a=None, 1899 num_row_partitions_b=None): 1900 if context.executing_eagerly(): 1901 return 1902 shape_a = DynamicRaggedShape.from_lengths( 1903 lengths_a, num_row_partitions=num_row_partitions_a) 1904 shape_b = DynamicRaggedShape.from_lengths( 1905 lengths_b, num_row_partitions=num_row_partitions_b) 1906 1907 result = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b) 1908 result_shape = result._to_tensor_shape() 1909 1910 tensor_shape_e = [None if isinstance(x, tuple) else x for x in shape_e] 1911 self.assertEqual(shape_e, result.static_lengths()) 1912 self.assertEqual(tensor_shape_e, result_shape.as_list()) 1913 1914 def testBroadcastDynamicShapePartiallyKnown(self): 1915 if context.executing_eagerly(): 1916 return 1917 @def_function.function( 1918 input_signature=[tensor_spec.TensorSpec(None, dtypes.int64)]) 1919 def fun(x): 1920 shape_a = DynamicRaggedShape([], array_ops.stack([5, x, 3])) 1921 shape_b = DynamicRaggedShape.from_lengths([1, 3], dtype=dtypes.int64) 1922 result = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b) 1923 self.assertAllEqual([5, None, 3], result.static_lengths()) 1924 fun(constant_op.constant(2, dtype=dtypes.int64)) 1925 1926 def testBroadcastDynamicShapePartiallyKnownNiceToHave(self): 1927 if context.executing_eagerly(): 1928 return 1929 @def_function.function( 1930 input_signature=[tensor_spec.TensorSpec(None, dtypes.int64)]) 1931 def fun(x): 1932 shape_a = DynamicRaggedShape([], array_ops.stack([5, x, 3])) 1933 shape_b = DynamicRaggedShape.from_lengths([2, 3], dtype=dtypes.int64) 1934 result = dynamic_ragged_shape.broadcast_dynamic_shape(shape_a, shape_b) 1935 self.assertAllEqual([5, 2, 3], result.static_lengths()) 1936 fun(constant_op.constant(2, dtype=dtypes.int64)) 1937 1938 def testFromRowPartitionsStatic(self): 1939 if context.executing_eagerly(): 1940 return 1941 rp = RowPartition.from_row_lengths([4, 2, 3]) 1942 result = DynamicRaggedShape.from_row_partitions([rp]) 1943 self.assertEqual([3, (4, 2, 3)], result.static_lengths()) 1944 1945 @parameterized.parameters([ 1946 dict( 1947 lengths_a=[3, (1, 4, 2)], dim=0, 1948 expected=3), 1949 dict( 1950 lengths_a=[5], dim=0, 1951 expected=5), 1952 dict( 1953 lengths_a=[5, (1, 4, 2, 1, 3), 3], 1954 dim=0, 1955 expected=5), 1956 dict( 1957 lengths_a=[5, (1, 4, 2, 1, 3), 3], 1958 dim=2, 1959 expected=3), 1960 dict( 1961 lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3], 1962 dim=1, 1963 expected=2), 1964 dict(lengths_a=[5, 1, 3], dim=0, expected=5), 1965 ]) 1966 def testDimStatic(self, lengths_a, dim, expected): 1967 if context.executing_eagerly(): 1968 return 1969 shape_a = DynamicRaggedShape.from_lengths(lengths_a) 1970 result = tensor_util.constant_value(shape_a[dim]) 1971 self.assertEqual(result, expected) 1972 1973 @parameterized.parameters([ 1974 dict( 1975 lengths_a=[5, (1, 4, 2, 1, 3), 3], 1976 shape_e=[5, (1, 4, 2, 1, 3), 3], 1977 new_num_row_partitions=2), # Fails 1978 dict( 1979 lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3], 1980 shape_e=[3, 2, (1, 4, 2, 1, 3, 1), 3], 1981 new_num_row_partitions=3), # Fails 1982 ]) 1983 def testNumRowPartitionShapeStatic(self, 1984 lengths_a, 1985 shape_e, 1986 new_num_row_partitions, 1987 num_row_partitions_a=None): 1988 if context.executing_eagerly(): 1989 return 1990 shape_a = DynamicRaggedShape.from_lengths( 1991 lengths_a, num_row_partitions=num_row_partitions_a) 1992 result = shape_a._with_num_row_partitions(new_num_row_partitions) 1993 self.assertEqual(shape_e, result.static_lengths()) 1994 1995 @parameterized.parameters([ 1996 dict(lengths_a=[5, (1, 4, 2, 1, 3), 3]), 1997 dict(lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3]), 1998 ]) 1999 def testFromLengthsNRowsStatic(self, lengths_a): 2000 if context.executing_eagerly(): 2001 return 2002 shape_a = DynamicRaggedShape.from_lengths(lengths_a) 2003 for rp in shape_a.row_partitions: 2004 actual = tensor_util.constant_value(rp.nrows()) 2005 self.assertIsNotNone(actual, 'Failed on ' + str(rp)) 2006 2007 @parameterized.parameters([ 2008 dict( 2009 lengths_a=[5, (1, 4, 2, 1, 3), 3], inner_shape=[33], 2010 new_inner_rank=1), 2011 dict( 2012 lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3], 2013 inner_shape=[36], 2014 new_inner_rank=1), 2015 dict( 2016 lengths_a=[3, 2, (1, 4, 2, 1, 3, 1), 3, 4], 2017 inner_shape=[36, 4], 2018 new_inner_rank=2), 2019 ]) 2020 def testAltInnerShapeStatic(self, 2021 lengths_a, 2022 inner_shape, 2023 new_inner_rank, 2024 num_row_partitions_a=None): 2025 if context.executing_eagerly(): 2026 return 2027 shape_a = DynamicRaggedShape.from_lengths( 2028 lengths_a, num_row_partitions=num_row_partitions_a) 2029 result = shape_a._alt_inner_shape(new_inner_rank) 2030 result_static = tensor_util.constant_value_as_shape(result) 2031 self.assertEqual(inner_shape, result_static.as_list()) 2032 2033 @parameterized.parameters([ 2034 dict( 2035 lengths=[3, (1, 4, 2)], 2036 shape_e=[3, None]), 2037 dict( 2038 lengths=[3, (1, 4, 2)], 2039 shape_e=[3, None]), 2040 dict( 2041 lengths=[5, (1, 4, 2, 1, 3), 3], 2042 shape_e=[5, None, 3]), 2043 dict( 2044 lengths=[5, (1, 4, 2, 1, 3), 3], 2045 shape_e=[5, None, 3]), 2046 dict( 2047 lengths=[3, 2, (1, 4, 2, 1, 3, 1), 3], 2048 shape_e=[3, 2, None, 3]), 2049 dict( 2050 lengths=[3, 2, (1, 4, 2, 1, 3, 1), 3], 2051 shape_e=[3, 2, None, 3]), 2052 ]) 2053 def testStaticShape(self, 2054 lengths, 2055 shape_e, 2056 num_row_partitions=None): 2057 # Testing the shape has enough information. 2058 # In particular, any uniform_row_length should be reproduced. 2059 if context.executing_eagerly(): 2060 return 2061 shape = DynamicRaggedShape.from_lengths( 2062 lengths, num_row_partitions=num_row_partitions) 2063 rt_a = ragged_array_ops.ragged_reshape( 2064 _lowest_primes(_num_elements_of_lengths(lengths)), shape) 2065 shape_e = tensor_shape.TensorShape(shape_e) 2066 self.assertEqual(shape_e.as_list(), rt_a.shape.as_list()) 2067 2068 @parameterized.parameters([ 2069 dict( 2070 lengths=[5, (1, 4, 2, 1, 3), 3], 2071 shape_e=[5, (1, 4, 2, 1, 3), 3]), 2072 dict( 2073 lengths=[3, 2, (1, 4, 2, 1, 3, 1), 3], 2074 shape_e=[3, 2, (1, 4, 2, 1, 3, 1), 3]), 2075 ]) 2076 def testWithNumRowPartitionsStatic(self, 2077 lengths, 2078 shape_e, 2079 num_row_partitions=None): 2080 # Note that this test loses the later static values. 2081 if context.executing_eagerly(): 2082 return 2083 shape = DynamicRaggedShape.from_lengths( 2084 lengths, num_row_partitions=num_row_partitions) 2085 shape_b = shape._with_num_row_partitions(shape.rank - 1) 2086 self.assertEqual(shape_e, shape_b.static_lengths()) 2087 2088 def testWithNumRowPartitionsStaticAlt(self): 2089 # Note that this test loses the later static values. 2090 if context.executing_eagerly(): 2091 return 2092 shape = DynamicRaggedShape.from_lengths( 2093 [5, 2, 3], num_row_partitions=2) 2094 shape_b = shape._with_num_row_partitions(0) 2095 self.assertEqual([5, 2, 3], shape_b.static_lengths()) 2096 2097 def testWithNumRowPartitionsDType(self): 2098 # Note that this test loses the later static values. 2099 shape = DynamicRaggedShape([], constant_op.constant([5, 2, 3], 2100 dtype=dtypes.int32)) 2101 self.assertEqual(shape.dtype, dtypes.int32) 2102 2103 result = shape._with_num_row_partitions(2) 2104 self.assertEqual(result.dtype, dtypes.int32) 2105 2106 def test_merge_with(self): 2107 original = DynamicRaggedShape.from_lengths([2, (3, 5), 6]) 2108 result = original._merge_with(original) 2109 self.assertShapeEq(result, original) 2110 2111 def test_merge_with_spec(self): 2112 original = DynamicRaggedShape.from_lengths([2, (3, 5), 6], 2113 dtype=dtypes.int64) 2114 spec = DynamicRaggedShape.Spec( 2115 row_partitions=[ 2116 RowPartitionSpec(nrows=2, 2117 nvals=8, 2118 dtype=dtypes.int64) 2119 ], 2120 static_inner_shape=tensor_shape.TensorShape([8, 6]), 2121 dtype=dtypes.int64) 2122 result = original._merge_with_spec(spec) 2123 self.assertShapeEq(result, original) 2124 2125 def test_merge_with_spec_raises(self): 2126 original = DynamicRaggedShape.from_lengths([2, (3, 5), 6], 2127 dtype=dtypes.int64) 2128 spec = DynamicRaggedShape.Spec( 2129 row_partitions=[ 2130 RowPartitionSpec(nrows=2, 2131 nvals=8, 2132 dtype=dtypes.int32) 2133 ], 2134 static_inner_shape=tensor_shape.TensorShape([8, 6]), 2135 dtype=dtypes.int32) 2136 with self.assertRaisesRegex( 2137 ValueError, 2138 'RowPartition and RowPartitionSpec are not compatible'): 2139 original._merge_with_spec(spec) 2140 2141 def test_merge_with_spec_uniform(self): 2142 original = DynamicRaggedShape.from_lengths( 2143 [2, (4, 4), 6], dtype=dtypes.int64) 2144 spec = DynamicRaggedShape.Spec._from_tensor_shape( 2145 tensor_shape.TensorShape([2, 4, 6]), 2146 num_row_partitions=0, 2147 dtype=dtypes.int64) 2148 result = original._merge_with_spec(spec) 2149 original = DynamicRaggedShape.from_lengths([2, 4, 6], 2150 num_row_partitions=1, 2151 dtype=dtypes.int64) 2152 self.assertShapeEq(result, original) 2153 2154 @parameterized.parameters([ 2155 dict( 2156 doc='x.shape=[3, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]', 2157 x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]], 2158 dtype=np.int32), 2159 y=[[10], [20], [30]], 2160 expected=ragged_factory_ops.constant_value([[11, 12, 13], [], 2161 [34, 35]])), 2162 dict( 2163 doc='x.shape=[3, (D1)]; y.shape=[]; bcast.shape=[3, (D1)]', 2164 x=ragged_factory_ops.constant_value([[1, 2, 3], [], [4, 5]], 2165 dtype=np.int32), 2166 y=10, 2167 expected=ragged_factory_ops.constant_value([[11, 12, 13], [], 2168 [14, 15]])), 2169 dict( 2170 doc='x.shape=[1, (D1)]; y.shape=[3, 1]; bcast.shape=[3, (D1)]', 2171 x=ragged_factory_ops.constant_value([[1, 2, 3]], dtype=np.int32), 2172 y=[[10], [20], [30]], 2173 expected=ragged_factory_ops.constant_value( 2174 [[11, 12, 13], [21, 22, 23], [31, 32, 33]], dtype=np.int32)), 2175 dict( 2176 doc=('x.shape=[2, (D1), 1]; y.shape=[1, (D2)]; ' 2177 'bcast.shape=[2, (D1), (D2)]'), 2178 x=ragged_factory_ops.constant_value([[[1], [2], [3]], [[4]]], 2179 ragged_rank=1), 2180 y=ragged_factory_ops.constant_value([[10, 20, 30]]), 2181 expected=ragged_factory_ops.constant_value([[[11, 21, 2182 31], [12, 22, 32], 2183 [13, 23, 33]], 2184 [[14, 24, 34]]])), 2185 dict( 2186 doc=('x.shape=[2, (D1), 1]; y.shape=[1, 1, 4]; ' 2187 'bcast.shape=[2, (D1), 4]'), 2188 x=ragged_factory_ops.constant_value([[[10], [20]], [[30]]], 2189 ragged_rank=1), 2190 y=[[[1, 2, 3, 4]]], 2191 expected=ragged_factory_ops.constant_value( 2192 [[[11, 12, 13, 14], [21, 22, 23, 24]], [[31, 32, 33, 34]]], 2193 ragged_rank=1)), 2194 dict( 2195 doc=('x.shape=[2, (D1), 2, 1]; y.shape=[2, (D2)]; ' 2196 'bcast.shape=[2, (D1), (2), (D2)'), 2197 x=ragged_factory_ops.constant_value( 2198 [[[[1], [2]], [[3], [4]]], [[[5], [6]]]], ragged_rank=1), 2199 y=ragged_factory_ops.constant_value([[10, 20], [30]]), 2200 expected=ragged_factory_ops.constant_value([[[[11, 21], [32]], 2201 [[13, 23], [34]]], 2202 [[[15, 25], [36]]]])), 2203 ]) 2204 def testRaggedDispatchImplWithBroadcasting(self, x, y, expected, doc): 2205 expected_rrank = getattr(expected, 'num_row_partitions', 0) 2206 x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32) 2207 y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32) 2208 result = dynamic_ragged_shape.ragged_binary_elementwise_op_impl( 2209 gen_math_ops.add_v2, x, y) 2210 result_rrank = getattr(result, 'num_row_partitions', 0) 2211 self.assertEqual(expected_rrank, result_rrank) 2212 if hasattr(expected, 'tolist'): 2213 expected = expected.tolist() 2214 self.assertAllEqual(result, expected) 2215 2216 def testDimensions(self): 2217 a = DynamicRaggedShape._from_inner_shape([1, 2, 3]) 2218 self.assertAllEqual(1, a._dimension(0)) 2219 2220 def testGetItemIsInstanceTensor(self): 2221 a = dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape([1, 2, 3]) 2222 self.assertIsInstance(a[0], ops.Tensor) 2223 2224 @parameterized.parameters([ 2225 dict( 2226 lengths=[2, 2], 2227 num_row_partitions=1, 2228 expected=[2, 2]), 2229 dict(lengths=[2, 2], num_row_partitions=0, expected=[2, 2]), 2230 dict( 2231 lengths=[2, (1, 2), 2], num_row_partitions=1, expected=[2, (1, 2), 2]) 2232 ]) 2233 def testStaticLengths(self, 2234 lengths, 2235 num_row_partitions, 2236 expected, 2237 expected_eager=None): 2238 a = DynamicRaggedShape.from_lengths(lengths)._with_num_row_partitions( 2239 num_row_partitions) 2240 actual = a.static_lengths() 2241 if context.executing_eagerly() and expected_eager is not None: 2242 self.assertAllEqual(expected_eager, actual) 2243 else: 2244 self.assertAllEqual(expected, actual) 2245 2246 def testStaticLengthsUnknown(self): 2247 2248 @def_function.function( 2249 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2250 def foo(row_lengths): 2251 a = DynamicRaggedShape([RowPartition.from_row_lengths(row_lengths)], [6]) 2252 actual = a.static_lengths() 2253 self.assertAllEqual([None, None], actual) 2254 2255 foo([3, 3]) 2256 2257 def testStaticLengthsRankUnknown(self): 2258 # Note that the rank of the shape is unknown, so we can only provide a 2259 # prefix of the lengths. 2260 @def_function.function( 2261 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2262 def foo(inner_shape): 2263 a = DynamicRaggedShape([RowPartition.from_row_lengths([3, 3])], 2264 inner_shape) 2265 actual = a.static_lengths() 2266 self.assertAllEqual([2, (3, 3), ...], actual) 2267 2268 foo([6, 3]) 2269 2270 def testReprRankKnown(self): 2271 a = DynamicRaggedShape.from_lengths([2, (1, 2), 3]) 2272 actual = str(a) 2273 self.assertEqual( 2274 '<DynamicRaggedShape lengths=[2, (1, 2), 3] num_row_partitions=1>', 2275 actual) 2276 2277 def assertDimsEqual(self, x: tensor_shape.TensorShape, 2278 y: tensor_shape.TensorShape): 2279 if x.rank is None: 2280 self.assertIsNone( 2281 y.rank, 2282 'x has an unknown rank, but y does not: x={}, y={}'.format(x, y)) 2283 return 2284 self.assertIsNotNone( 2285 y.rank, 2286 'y has an unknown rank, but x does not: x={}, y={}'.format(x, y)) 2287 self.assertAllEqual(x.as_list(), y.as_list()) 2288 2289 def testToTensorShapeRankKnown(self): 2290 a = DynamicRaggedShape.from_lengths([2, (1, 2), 3]) 2291 actual = a._to_tensor_shape() 2292 self.assertDimsEqual(tensor_shape.TensorShape([2, None, 3]), actual) 2293 2294 def testReprRankUnknown(self): 2295 2296 @def_function.function( 2297 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2298 def foo(inner_shape): 2299 a = DynamicRaggedShape([RowPartition.from_row_lengths([3, 3])], 2300 inner_shape) 2301 actual = str(a) 2302 self.assertEqual( 2303 '<DynamicRaggedShape lengths=[2, (3, 3), ...] num_row_partitions=1>', 2304 actual) 2305 2306 foo([6, 3]) 2307 2308 def testToTensorShapeRankUnknown(self): 2309 @def_function.function( 2310 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2311 def foo(inner_shape): 2312 a = DynamicRaggedShape([RowPartition.from_row_lengths([3, 3])], 2313 inner_shape) 2314 actual = a._to_tensor_shape() 2315 self.assertDimsEqual( 2316 tensor_shape.TensorShape(None), actual) 2317 2318 foo([6, 3]) 2319 2320 def testBroadcastDynamicShapeExtendedRankOne(self): 2321 a = DynamicRaggedShape._from_inner_shape([1]) 2322 b = DynamicRaggedShape._from_inner_shape([3]) 2323 (c, ac, bc) = dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b) 2324 expected_c = DynamicRaggedShape._from_inner_shape([3]) 2325 self.assertShapeEq(c, expected_c) 2326 ac_result = ac.broadcast(constant_op.constant([4])) 2327 self.assertAllEqual(ac_result, [4, 4, 4]) 2328 bc_result = bc.broadcast(constant_op.constant([4, 7, 1])) 2329 self.assertAllEqual(bc_result, [4, 7, 1]) 2330 2331 def testBroadcastDynamicShapeExtendedRankOneRev(self): 2332 a = DynamicRaggedShape._from_inner_shape([3]) 2333 b = DynamicRaggedShape._from_inner_shape([1]) 2334 (c, ac, bc) = dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b) 2335 expected_c = DynamicRaggedShape._from_inner_shape([3]) 2336 self.assertShapeEq(c, expected_c) 2337 bc_result = bc.broadcast(constant_op.constant([4])) 2338 self.assertAllEqual(bc_result, [4, 4, 4]) 2339 ac_result = ac.broadcast(constant_op.constant([4, 7, 1])) 2340 self.assertAllEqual(ac_result, [4, 7, 1]) 2341 2342 def testBroadcastDynamicShapeExtendedRankOneIdentity(self): 2343 a = DynamicRaggedShape._from_inner_shape([3]) 2344 b = DynamicRaggedShape._from_inner_shape([3]) 2345 (c, ac, bc) = dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b) 2346 expected_c = DynamicRaggedShape._from_inner_shape([3]) 2347 self.assertShapeEq(c, expected_c) 2348 bc_result = bc.broadcast(constant_op.constant([4, 7, 1])) 2349 self.assertAllEqual(bc_result, [4, 7, 1]) 2350 ac_result = ac.broadcast(constant_op.constant([4, 7, 1])) 2351 self.assertAllEqual(ac_result, [4, 7, 1]) 2352 2353 def testFromGatherLayerIndexRaises(self): 2354 bad_gather_index = constant_op.constant([0.0, 0.5, 1.0]) 2355 with self.assertRaisesRegex(ValueError, 'gather_index must be'): 2356 _LayerBroadcaster.from_gather_index(bad_gather_index) 2357 2358 ### Tests mostly for code coverage ########################################### 2359 2360 def testFindPreferredDtypeIntNone(self): 2361 actual = dynamic_ragged_shape._find_dtype(3, None) 2362 self.assertIsNone(actual) 2363 2364 @parameterized.parameters([ 2365 dict( 2366 source_shape=lambda: DynamicRaggedShape._from_inner_shape([3]), 2367 target_shape=lambda: DynamicRaggedShape._from_inner_shape([3]), 2368 layer_broadcasters=lambda: [int], 2369 dtype=None, 2370 error_type=TypeError, 2371 error_regex=r'Not a LayerBroadcaster'), 2372 dict( 2373 source_shape=lambda: DynamicRaggedShape._from_inner_shape([3]), 2374 target_shape=lambda: DynamicRaggedShape._from_inner_shape([3]), 2375 layer_broadcasters=lambda: _LayerBroadcaster.from_gather_index( 2376 [0, 1, 2]), 2377 dtype=None, 2378 error_type=TypeError, 2379 error_regex=r'layer'), 2380 dict( 2381 source_shape=lambda: DynamicRaggedShape._from_inner_shape([3]), 2382 target_shape=lambda: None, 2383 layer_broadcasters=lambda: 2384 [_LayerBroadcaster.from_gather_index([0, 1, 2])], 2385 dtype=None, 2386 error_type=TypeError, 2387 error_regex='target_shape is not a DynamicRaggedShape'), 2388 dict( 2389 source_shape=lambda: None, 2390 target_shape=lambda: DynamicRaggedShape._from_inner_shape([3]), 2391 layer_broadcasters=lambda: 2392 [_LayerBroadcaster.from_gather_index([0, 1, 2])], 2393 dtype=None, 2394 error_type=TypeError, 2395 error_regex='source_shape is not a DynamicRaggedShape') 2396 ]) 2397 def testBroadcasterInitRaises(self, source_shape, target_shape, 2398 layer_broadcasters, dtype, error_type, 2399 error_regex): 2400 source_shape = source_shape() 2401 target_shape = target_shape() 2402 layer_broadcasters = layer_broadcasters() 2403 with self.assertRaisesRegex(error_type, error_regex): 2404 dynamic_ragged_shape._Broadcaster( 2405 source_shape, target_shape, layer_broadcasters, dtype=dtype) 2406 2407 def testBroadcasterRepr(self): 2408 source_shape = DynamicRaggedShape( 2409 [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))], 2410 constant_op.constant([3])) 2411 target_shape = DynamicRaggedShape( 2412 [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))], 2413 constant_op.constant([3])) 2414 layer_broadcasters = [ 2415 _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2])), 2416 _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2])) 2417 ] 2418 bc = dynamic_ragged_shape._Broadcaster(source_shape, target_shape, 2419 layer_broadcasters) 2420 actual = str(bc) 2421 self.assertRegex(actual, '.src_shape..DynamicRaggedShape') 2422 2423 def testBroadcasterWithDtype(self): 2424 source_shape = DynamicRaggedShape( 2425 [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))], 2426 constant_op.constant([3])) 2427 target_shape = DynamicRaggedShape( 2428 [RowPartition.from_row_splits(constant_op.constant([0, 1, 2]))], 2429 constant_op.constant([3])) 2430 layer_broadcasters = [ 2431 _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2])), 2432 _LayerBroadcaster.from_gather_index(constant_op.constant([0, 1, 2])) 2433 ] 2434 bc = dynamic_ragged_shape._Broadcaster( 2435 source_shape, target_shape, layer_broadcasters, dtype=dtypes.int32) 2436 2437 bc2 = bc.with_dtype(dtypes.int64) 2438 self.assertEqual(bc2.dtype, dtypes.int64) 2439 2440 # TODO(martinz): This doesn't work for ragged_tensor_shape. 2441 # Uncomment when we switch over the implementation. 2442 # dict(dtype=dtypes.int32) 2443 @parameterized.parameters([ 2444 dict(dtype=dtypes.int64) 2445 ]) 2446 def testBroadcasterWithDenseDType(self, dtype): 2447 a = constant_op.constant([[4]]) 2448 b = RaggedTensor.from_row_splits([[2], [3], [4], [5]], [0, 3, 4]) 2449 b = b.with_row_splits_dtype(dtype) 2450 c = a + b 2451 self.assertEqual(c.row_splits.dtype, dtype) 2452 d = b + a 2453 self.assertEqual(d.row_splits.dtype, dtype) 2454 2455 @parameterized.parameters([ 2456 dict(dtype_left=dtypes.int64, 2457 dtype_right=dtypes.int32), 2458 dict(dtype_left=dtypes.int32, 2459 dtype_right=dtypes.int64)]) 2460 def testBroadcastWithDifferentDenseShapeDTypes(self, dtype_left, 2461 dtype_right): 2462 s_left = DynamicRaggedShape._from_inner_shape( 2463 constant_op.constant([4, 1], dtype_left)) 2464 s_right = DynamicRaggedShape._from_inner_shape( 2465 constant_op.constant([1, 4], dtype_right)) 2466 s_result = dynamic_ragged_shape.broadcast_dynamic_shape(s_left, s_right) 2467 self.assertEqual(s_result.dtype, dtypes.int64) 2468 2469 def testBroadcastFlatValuesToDenseExpand(self): 2470 source = RaggedTensor.from_uniform_row_length([0, 1, 2, 3], 2) 2471 target_shape = DynamicRaggedShape._from_inner_shape([1, 2, 2]) 2472 broadcaster = dynamic_ragged_shape._get_broadcaster( 2473 DynamicRaggedShape.from_tensor(source), target_shape) 2474 flat_values = broadcaster.broadcast_flat_values(source) 2475 self.assertAllEqual(flat_values, [[[0, 1], [2, 3]]]) 2476 2477 # TODO(edloper): Confirm that this is the expected behavior. 2478 def testBroadcastFlatValuesToDenseExpandInnerDimensionsFalse(self): 2479 source = RaggedTensor.from_uniform_row_length([0, 1, 2, 3], 2) 2480 target_shape = DynamicRaggedShape._from_inner_shape([1, 2, 2]) 2481 broadcaster = dynamic_ragged_shape._get_broadcaster( 2482 DynamicRaggedShape.from_tensor(source), target_shape) 2483 flat_values = broadcaster.broadcast_flat_values( 2484 source, inner_dimensions=False) 2485 self.assertAllEqual(flat_values, [[0, 1], [2, 3]]) 2486 2487 def testGetLayerBroadcastersFromRPSRaisesTypeError(self): 2488 with self.assertRaisesRegex(TypeError, 'Not a _LayerBroadcaster'): 2489 dynamic_ragged_shape._get_layer_broadcasters_from_rps(int, [], []) 2490 2491 def testGetBroadcasterRankDrop(self): 2492 with self.assertRaisesRegex(ValueError, 'Cannot broadcast'): 2493 a = DynamicRaggedShape._from_inner_shape([3, 4, 5]) 2494 b = DynamicRaggedShape._from_inner_shape([4, 5]) 2495 dynamic_ragged_shape._get_broadcaster(a, b) 2496 2497 @parameterized.parameters([ 2498 dict( 2499 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2500 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2501 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2502 b_1=lambda: None, 2503 error_type=TypeError, 2504 error_regex='b_1 should be a RowPartition'), 2505 dict( 2506 ac_0=lambda: None, 2507 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2508 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2509 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2510 error_type=TypeError, 2511 error_regex='ac_0 should be a _LayerBroadcaster'), 2512 dict( 2513 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2514 bc_0=lambda: None, 2515 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2516 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2517 error_type=TypeError, 2518 error_regex='bc_0 should be a _LayerBroadcaster'), 2519 dict( 2520 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2521 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2522 a_1=lambda: None, 2523 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2524 error_type=TypeError, 2525 error_regex='a_1 should be a RowPartition') 2526 ]) 2527 def testBroadcastDynamicShapeNextLayerHalfRaggedRaises( 2528 self, ac_0, bc_0, a_1, b_1, error_type, error_regex): 2529 ac_0 = ac_0() 2530 bc_0 = bc_0() 2531 a_1 = a_1() 2532 b_1 = b_1() 2533 with self.assertRaisesRegex(error_type, error_regex): 2534 dynamic_ragged_shape._broadcast_dynamic_shape_next_layer_half_ragged( 2535 ac_0, bc_0, a_1, b_1) 2536 2537 @parameterized.parameters([ 2538 dict( 2539 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2540 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2541 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2542 b_1=lambda: None, 2543 error_type=TypeError, 2544 error_regex='b_1 should be a RowPartition'), 2545 dict( 2546 ac_0=lambda: None, 2547 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2548 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2549 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2550 error_type=TypeError, 2551 error_regex='ac_0 should be a _LayerBroadcaster'), 2552 dict( 2553 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2554 bc_0=lambda: None, 2555 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2556 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2557 error_type=TypeError, 2558 error_regex='bc_0 should be a _LayerBroadcaster'), 2559 dict( 2560 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2561 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2562 a_1=lambda: None, 2563 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2564 error_type=TypeError, 2565 error_regex='a_1 should be a RowPartition') 2566 ]) 2567 def testBroadcastDynamicShapeNextLayerBothUniformRaises( 2568 self, ac_0, bc_0, a_1, b_1, error_type, error_regex): 2569 ac_0 = ac_0() 2570 bc_0 = bc_0() 2571 a_1 = a_1() 2572 b_1 = b_1() 2573 with self.assertRaisesRegex(error_type, error_regex): 2574 dynamic_ragged_shape._broadcast_dynamic_shape_next_layer_both_uniform( 2575 ac_0, bc_0, a_1, b_1) 2576 2577 @parameterized.parameters([ 2578 dict( 2579 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2580 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2581 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2582 b_1=lambda: None, 2583 error_type=TypeError, 2584 error_regex='b_1 should be a RowPartition'), 2585 dict( 2586 ac_0=lambda: None, 2587 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2588 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2589 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2590 error_type=TypeError, 2591 error_regex='ac_0 should be a _LayerBroadcaster'), 2592 dict( 2593 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2594 bc_0=lambda: None, 2595 a_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2596 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2597 error_type=TypeError, 2598 error_regex='bc_0 should be a _LayerBroadcaster'), 2599 dict( 2600 ac_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2601 bc_0=lambda: _LayerBroadcaster.from_gather_index([0, 1, 2]), 2602 a_1=lambda: None, 2603 b_1=lambda: RowPartition.from_row_splits([0, 1, 2]), 2604 error_type=TypeError, 2605 error_regex='a_1 should be a RowPartition') 2606 ]) 2607 def testBroadcastDynamicShapeNextLayerRaises(self, ac_0, bc_0, a_1, b_1, 2608 error_type, error_regex): 2609 ac_0 = ac_0() 2610 bc_0 = bc_0() 2611 a_1 = a_1() 2612 b_1 = b_1() 2613 with self.assertRaisesRegex(error_type, error_regex): 2614 dynamic_ragged_shape._broadcast_dynamic_shape_next_layer( 2615 ac_0, bc_0, a_1, b_1) 2616 2617 @parameterized.parameters([ 2618 dict( 2619 left_dtype=dtypes.int64, 2620 right_dtype=dtypes.int64, 2621 expected_dtype=dtypes.int64), 2622 dict( 2623 left_dtype=dtypes.int32, 2624 right_dtype=dtypes.int32, 2625 expected_dtype=dtypes.int32) 2626 ]) 2627 def testAddingRowSplits(self, left_dtype, right_dtype, expected_dtype): 2628 x = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(left_dtype) 2629 y = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(right_dtype) 2630 z = math_ops.add(x, y) 2631 self.assertEqual(z.row_splits.dtype, expected_dtype) 2632 2633 @parameterized.parameters([ 2634 dict(left_dtype=dtypes.int32, right_dtype=dtypes.int64), 2635 dict(left_dtype=dtypes.int64, right_dtype=dtypes.int32), 2636 ]) 2637 def testAddingRowSplitsError(self, left_dtype, right_dtype): 2638 x = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(left_dtype) 2639 y = ragged_factory_ops.constant([[1, 2]]).with_row_splits_dtype(right_dtype) 2640 with self.assertRaisesRegex( 2641 ValueError, 'Input RaggedTensors have mismatched row_splits dtypes'): 2642 math_ops.add(x, y) 2643 2644 def testAddRowPartitionsInvalidV1(self): 2645 if not context.executing_eagerly(): 2646 return 2647 2648 with self.assertRaisesRegex( 2649 (errors_impl.InvalidArgumentError, ValueError), 2650 'Last row partition does not match flat_values.'): 2651 rt = ragged_factory_ops.constant([[3], [4, 5], [6]]) 2652 rt_shape = DynamicRaggedShape.from_tensor(rt) 2653 new_flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e']) 2654 rt_shape._add_row_partitions(new_flat_values, validate=True) 2655 2656 # Example #1: 2657 # [2, (3, 1), 5], num_row_partitions = 1, outer_axis = 0, inner_axis = 1. 2658 # Result: [4, 5], num_row_partitions = 0. 2659 # Example #2: 2660 # [2, (2, 1), (7, 8, 9), 5], num_row_partitions = 2, outer_axis = 1, 2661 # inner_axis = 2. 2662 # Result: [2, (15, 9), 5], num_row_partitions = 1. 2663 # Example #3: 2664 # [2, (2, 1), (7, 8, 9), 5], num_row_partitions = 2, outer_axis = 0, 2665 # inner_axis = 1. 2666 # Result: [(7, 8, 9), 5], num_row_partitions = 1. 2667 # Here, we are merging the tail of the row_partitions, 2668 # but the inner_shape is unchanged. 2669 2670 @parameterized.parameters([ 2671 # NOOP 2672 dict( 2673 lengths=[2, (3, 1), 5], 2674 num_row_partitions=1, 2675 outer_axis=1, 2676 inner_axis=1, 2677 expected_lengths=[2, (3, 1), 5], 2678 expected_num_row_partitions=1), 2679 # Where num_row_partitions == 0 2680 dict( 2681 lengths=[2, 7, 5, 4], 2682 num_row_partitions=0, 2683 outer_axis=1, 2684 inner_axis=2, 2685 expected_lengths=[2, 35, 4], 2686 expected_num_row_partitions=0), 2687 # Where inner_axis <= self.num_row_partitions 2688 dict( 2689 lengths=[2, (3, 1), 5], 2690 num_row_partitions=1, 2691 outer_axis=0, 2692 inner_axis=1, 2693 expected_lengths=[4, 5], 2694 expected_num_row_partitions=0), 2695 dict( 2696 lengths=[2, (2, 1), (7, 8, 9), 5], 2697 num_row_partitions=2, 2698 outer_axis=1, 2699 inner_axis=2, 2700 expected_lengths=[2, (15, 9), 5], 2701 expected_num_row_partitions=1), 2702 # outer_axis > num_row_partitions (only inner_shape changed) 2703 dict( 2704 lengths=[2, (1, 2), 5, 3], 2705 num_row_partitions=1, 2706 outer_axis=2, 2707 inner_axis=3, 2708 expected_lengths=[2, (1, 2), 15], 2709 expected_num_row_partitions=1), 2710 # outer_axis <= num_row_partitions 2711 # inner_axis > num_row_partitions (everything changes) 2712 # (If outer_axis == 0, all row_partitions are truncated). 2713 dict( 2714 lengths=[2, (2, 1), (7, 8, 9), 2, 5], 2715 num_row_partitions=2, 2716 outer_axis=0, 2717 inner_axis=3, 2718 expected_lengths=[48, 5], 2719 expected_num_row_partitions=0), 2720 dict( 2721 lengths=[2, (2, 1), (7, 8, 9), 2, 5], 2722 num_row_partitions=2, 2723 outer_axis=1, 2724 inner_axis=3, 2725 expected_lengths=[2, (30, 18), 5], 2726 expected_num_row_partitions=1), 2727 ]) 2728 def test_merge_dims(self, lengths, num_row_partitions, outer_axis, inner_axis, 2729 expected_lengths, expected_num_row_partitions): 2730 original = DynamicRaggedShape.from_lengths( 2731 lengths, num_row_partitions=num_row_partitions) 2732 actual = original._merge_dims(outer_axis, inner_axis) 2733 expected = DynamicRaggedShape.from_lengths(expected_lengths, 2734 expected_num_row_partitions) 2735 self.assertShapeEq(actual, expected) 2736 2737 def test_merge_dims_special(self): 2738 rt = ragged_factory_ops.constant([[[1, 2], [3]], [[4]]]) 2739 original = DynamicRaggedShape.from_tensor(rt) 2740 actual = original._merge_dims(0, 1) 2741 self.assertAllEqual(actual[0], 3) 2742 2743 def testGetItemRankNoneTruncate(self): 2744 @def_function.function( 2745 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2746 def foo(x): 2747 rts = DynamicRaggedShape.from_tensor(x) 2748 actual = rts[:1] 2749 self.assertShapeEq(rts, actual) 2750 2751 foo([1, 2, 3]) 2752 2753 def test_dataset_only_dense(self): 2754 ragged = DynamicRaggedShape.from_lengths([4, 5, 2, 3]) 2755 dataset_ops.DatasetV2.from_tensors(ragged) 2756 2757 def test_dataset_only_ragged(self): 2758 ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5), 2, 3]) 2759 dataset_ops.DatasetV2.from_tensors(ragged) 2760 2761 def test_ragged_dataset(self): 2762 rt = RaggedTensor.from_row_splits(array_ops.zeros([5, 2, 3]), [0, 3, 5]) 2763 dataset_ops.DatasetV2.from_tensors(rt) 2764 2765 def test_ones_shape(self): 2766 ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5)]) 2767 ones = dynamic_ragged_shape.ones(ragged, dtype=bool) 2768 sh2 = DynamicRaggedShape.from_tensor(ones) 2769 self.assertAllEqual(sh2.static_lengths(), [4, (3, 0, 4, 5)]) 2770 2771 def test_dataset_only_simple_ragged(self): 2772 ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5)]) 2773 dataset_ops.DatasetV2.from_tensors(ragged) 2774 2775 # ValueError: _to_batched_tensor_list doesn't support ragged_rank=0 yet 2776 def test_unbatch_batch_dense(self): 2777 ragged = DynamicRaggedShape.from_lengths([4, 5, 2, 3]) 2778 ds = dataset_ops.DatasetV2.from_tensors(ragged) 2779 dsu = ds.unbatch() 2780 if context.executing_eagerly(): 2781 values = list(dsu) 2782 self.assertAllEqual(values[0].static_lengths(), [5, 2, 3]) 2783 self.assertAllEqual(values[2].static_lengths(), [5, 2, 3]) 2784 2785 dsb = dsu.batch(2) 2786 if context.executing_eagerly(): 2787 valuesb = list(dsb) 2788 self.assertAllEqual(valuesb[0].static_lengths(), [2, 5, 2, 3]) 2789 self.assertAllEqual(valuesb[1].static_lengths(), [2, 5, 2, 3]) 2790 2791 def test_unbatch_batch_values_shape_0(self): 2792 batched = DynamicRaggedShape.from_lengths([2]) 2793 batch_size = 2 2794 ds = dataset_ops.Dataset.from_tensors(batched) 2795 ds2 = ds.unbatch() 2796 if context.executing_eagerly(): 2797 v = list(ds2.batch(batch_size)) 2798 self.assertAllEqual(v[0], batched) 2799 2800 def test_unbatch_batch_values_shape_1(self): 2801 batched = DynamicRaggedShape.from_lengths([2, 3]) 2802 rebatched = DynamicRaggedShape.from_lengths([2, 3], num_row_partitions=1) 2803 2804 batch_size = 2 2805 ds = dataset_ops.Dataset.from_tensors(batched) 2806 ds2 = ds.unbatch() 2807 if context.executing_eagerly(): 2808 v = list(ds2.batch(batch_size)) 2809 self.assertAllEqual(v[0], rebatched) 2810 2811 def test_unbatch_dense_matrix(self): 2812 ragged = DynamicRaggedShape.from_lengths([2, 3]) 2813 ds = dataset_ops.DatasetV2.from_tensors(ragged) 2814 dsu = ds.unbatch() 2815 if context.executing_eagerly(): 2816 values = list(dsu) 2817 self.assertAllEqual(values[0].static_lengths(), [3]) 2818 self.assertAllEqual(values[1].static_lengths(), [3]) 2819 2820 def test_unbatch_dense_vector(self): 2821 ragged = DynamicRaggedShape.from_lengths([3]) 2822 ds = dataset_ops.DatasetV2.from_tensors(ragged) 2823 dsu = ds.unbatch() 2824 if context.executing_eagerly(): 2825 values = list(dsu) 2826 self.assertAllEqual(values[0].static_lengths(), []) 2827 self.assertAllEqual(values[1].static_lengths(), []) 2828 2829 def test_unbatch_ragged(self): 2830 ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5), 2, 3]) 2831 ds = dataset_ops.DatasetV2.from_tensors(ragged) 2832 dsu = ds.unbatch() 2833 if context.executing_eagerly(): 2834 dsu.__iter__() 2835 2836 def test_unbatch_batch_ragged(self): 2837 ragged = DynamicRaggedShape.from_lengths([4, (3, 0, 4, 5), 2, 3]) 2838 ds = dataset_ops.DatasetV2.from_tensors(ragged) 2839 dsu = ds.unbatch() 2840 if context.executing_eagerly(): 2841 values = list(dsu) 2842 self.assertAllEqual(values[0].static_lengths(), [3, 2, 3]) 2843 self.assertAllEqual(values[2].static_lengths(), [4, 2, 3]) 2844 2845 dsb = dsu.batch(2) 2846 if context.executing_eagerly(): 2847 valuesb = list(dsb) 2848 self.assertAllEqual(valuesb[0].static_lengths(), [2, (3, 0), 2, 3]) 2849 self.assertAllEqual(valuesb[1].static_lengths(), [2, (4, 5), 2, 3]) 2850 2851 2852class DynamicRaggedShapeErrorTest(parameterized.TestCase): 2853 2854 @parameterized.parameters([ 2855 # Broadcast [1, 2, (1, 2)] to [1, 2, (2, 1)] (FAIL) 2856 dict( 2857 origin_lengths=[2, 1, (1, 2)], 2858 origin_values=[2, 3, 5], 2859 expected_lengths=[1, 2, (2, 1)]), 2860 # Broadcast [2, 1, (1, 1)] -> [2, 1, (5, 5)] (UNSUPPORTED) 2861 dict( 2862 origin_lengths=[2, 1, (1, 1)], 2863 origin_values=[2, 3], 2864 expected_lengths=[2, 1, (5, 5)]), 2865 # Broadcast [1, 2, (1, 2)] to [2, 2, (2, 1, 1, 2)] (FAIL) 2866 dict( 2867 origin_lengths=[1, 2, (1, 2)], 2868 origin_values=[2, 3, 5], 2869 expected_lengths=[2, 2, (2, 1, 1, 2)]), 2870 # Broadcast w.shape = [2,1,(1,3)] to w'.shape = [2,1,(3,3)] (UNSUPPORTED) 2871 dict( 2872 origin_lengths=[2, 1, (1, 3)], 2873 origin_values=[2, 3, 5, 7], # [[[2]], [[3, 5, 7]]] 2874 expected_lengths=[2, 1, (3, 3)]), 2875 ]) 2876 def testBroadcastRaggedError(self, origin_lengths, origin_values, 2877 expected_lengths): 2878 # I pulled this out of the tensorflow test case, so that I could have 2879 # more control. 2880 # However this error is being generated, it confuses assertRaises, 2881 # but it exists. 2882 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2883 r'Cannot broadcast'): 2884 # with self.assertRaisesRegex(errors.InvalidArgumentError, 2885 # r"Cannot broadcast"): 2886 sess = session.Session() 2887 with sess.as_default(): 2888 origin = _to_ragged_tensor_from_lengths(origin_values, origin_lengths) 2889 expected_shape = DynamicRaggedShape.from_lengths(expected_lengths) 2890 2891 rt = dynamic_ragged_shape.broadcast_to(origin, expected_shape) 2892 sess.run([rt]) 2893 2894 @parameterized.parameters([ 2895 # nvals and nrows don't match (3 != 4) dynamically 2896 dict( 2897 row_partitions=lambda: [ # pylint: disable=g-long-lambda 2898 RowPartition.from_uniform_row_length(1, 3, nrows=3), 2899 RowPartition.from_uniform_row_length(1, 4, nrows=4) 2900 ], 2901 inner_shape=lambda: [4], 2902 validate=True, 2903 error_regex='RowPartitions in DynamicRaggedShape do not'), 2904 # nvals and inner_shape[0] don't match (3 != 4) dynamically 2905 dict( 2906 row_partitions=lambda: [ # pylint: disable=g-long-lambda 2907 RowPartition.from_uniform_row_length(1, 3, nrows=3), 2908 ], 2909 inner_shape=lambda: [4], 2910 validate=True, 2911 error_regex='Last row partition does not match inner_shape.'), 2912 ]) 2913 def testConstructorRaisesDynamic(self, 2914 row_partitions, 2915 inner_shape, 2916 error_regex, 2917 validate=False, 2918 dtype=None): 2919 with self.assertRaisesRegex((errors_impl.InvalidArgumentError, ValueError), 2920 error_regex): 2921 sess = session.Session() 2922 with sess.as_default(): 2923 row_partitions = row_partitions() 2924 inner_shape = inner_shape() 2925 rts = DynamicRaggedShape( 2926 row_partitions, inner_shape, dtype=dtype, validate=validate) 2927 sess.run([rts.inner_shape]) 2928 2929 def testRankNone(self): 2930 2931 @def_function.function( 2932 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2933 def foo(x): 2934 rts = DynamicRaggedShape._from_inner_shape(x) 2935 self.assertIsNone(rts.rank) 2936 2937 foo([3, 7, 5]) 2938 2939 def testNumSlicesInDimensionRankNone(self): 2940 with self.assertRaisesRegex(ValueError, 'rank is undefined'): 2941 2942 @def_function.function( 2943 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2944 def foo(x): 2945 rts = DynamicRaggedShape._from_inner_shape(x) 2946 rts._num_slices_in_dimension(-1) 2947 2948 foo([3, 7, 5]) 2949 2950 def testGetItemRankNone(self): 2951 with self.assertRaisesRegex(ValueError, 'Rank must be known to'): 2952 2953 @def_function.function( 2954 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2955 def foo(x): 2956 rts = DynamicRaggedShape._from_inner_shape(x) 2957 rts[-1] # pylint: disable=pointless-statement 2958 2959 foo([3, 7, 5]) 2960 2961 def testWithDenseRankRankNone(self): 2962 with self.assertRaisesRegex(ValueError, 'Rank must be known to'): 2963 2964 @def_function.function( 2965 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2966 def foo(x): 2967 rts = DynamicRaggedShape._from_inner_shape(x) 2968 rts._with_inner_rank(1) 2969 2970 foo([3, 7, 5]) 2971 2972 def testWithRaggedRankRankNone(self): 2973 with self.assertRaisesRegex(ValueError, 'Rank must be known to'): 2974 2975 @def_function.function( 2976 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2977 def foo(x): 2978 rts = DynamicRaggedShape._from_inner_shape(x) 2979 rts._with_num_row_partitions(1) 2980 2981 foo([3, 7, 5]) 2982 2983 def testAsRowPartitionsRankNone(self): 2984 # Error is readable, but does not match strings correctly. 2985 with self.assertRaisesRegex(ValueError, ''): 2986 2987 @def_function.function( 2988 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 2989 def foo(x): 2990 rts = DynamicRaggedShape._from_inner_shape(x) 2991 rts._as_row_partitions() 2992 2993 foo([3, 7, 5]) 2994 2995 def testBroadcastDynamicShapeExtendedRankNone(self): 2996 with self.assertRaisesRegex(ValueError, 2997 'Unable to broadcast: unknown rank'): 2998 2999 @def_function.function( 3000 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 3001 def foo(x): 3002 a = DynamicRaggedShape._from_inner_shape(x) 3003 b = DynamicRaggedShape._from_inner_shape([1, 1, 1]) 3004 dynamic_ragged_shape.broadcast_dynamic_shape_extended(a, b) 3005 3006 foo([3, 7, 5]) 3007 3008 def testBroadcastDynamicShapeUnmatchedTypes6432(self): 3009 shape_int64 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)], 3010 dtype=dtypes.int64) 3011 shape_int32 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)], 3012 dtype=dtypes.int32) 3013 with self.assertRaisesRegex(ValueError, "Dtypes don't match"): 3014 dynamic_ragged_shape.broadcast_dynamic_shape(shape_int64, shape_int32) 3015 3016 def testBroadcastDynamicShapeUnmatchedTypes3264(self): 3017 shape_int64 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)], 3018 dtype=dtypes.int64) 3019 shape_int32 = DynamicRaggedShape.from_lengths([3, (0, 2, 3)], 3020 dtype=dtypes.int32) 3021 with self.assertRaisesRegex(ValueError, "Dtypes don't match"): 3022 dynamic_ragged_shape.broadcast_dynamic_shape(shape_int32, shape_int64) 3023 3024 def testGetIdentityBroadcasterRankNone(self): 3025 with self.assertRaisesRegex(ValueError, 'Shape must have a'): 3026 3027 @def_function.function( 3028 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 3029 def foo(x): 3030 rts = DynamicRaggedShape._from_inner_shape(x) 3031 dynamic_ragged_shape._get_identity_broadcaster(rts) 3032 3033 foo([3, 7, 5]) 3034 3035 def testLayerBroadcasterRepr(self): 3036 index = constant_op.constant([0, 1, 2], name='testLayerBroadcasterRepr') 3037 lb = _LayerBroadcaster.from_gather_index(index) 3038 actual = str(lb) 3039 self.assertRegex(actual, '.*Tensor.*, shape=.3... dtype=int32.') 3040 3041 def testGetBroadcasterRankNoneLeft(self): 3042 with self.assertRaisesRegex(ValueError, 'Rank of source and target must'): 3043 3044 @def_function.function( 3045 input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)]) 3046 def foo(x): 3047 rts_a = DynamicRaggedShape._from_inner_shape(x) 3048 rts_b = DynamicRaggedShape._from_inner_shape(x) 3049 dynamic_ragged_shape._get_broadcaster(rts_a, rts_b) 3050 3051 foo([3, 7, 5]) 3052 3053 def testFromTensorDType(self): 3054 x = ragged_factory_ops.constant([[1, 2]]) 3055 self.assertEqual(x.row_splits.dtype, dtypes.int64) 3056 shape_x = DynamicRaggedShape.from_tensor(x) 3057 self.assertEqual(shape_x.dtype, dtypes.int64) 3058 3059 def testAddingRowSplits(self): 3060 x = ragged_factory_ops.constant([[1, 2]]) 3061 self.assertEqual(x.row_splits.dtype, dtypes.int64) 3062 3063 y = math_ops.add(x, x) 3064 self.assertEqual(y.row_splits.dtype, dtypes.int64) 3065 3066 def testHashingWithMask(self): 3067 inp_data = ragged_factory_ops.constant( 3068 [['omar', 'stringer', 'marlo', 'wire'], ['marlo', 'skywalker', 'wire']], 3069 dtype=dtypes.string) 3070 mask = math_ops.equal(inp_data, '') 3071 values = string_ops.string_to_hash_bucket_strong( 3072 inp_data, 3, name='hash', key=[0xDECAFCAFFE, 0xDECAFCAFFE]) 3073 values = math_ops.add(values, array_ops.ones_like(values)) 3074 local_zeros = array_ops.zeros_like(values) 3075 values = array_ops.where(mask, local_zeros, values) 3076 3077 def testAddRowPartitionsInvalid(self): 3078 with self.assertRaisesRegex( 3079 (errors_impl.InvalidArgumentError, ValueError), 3080 'Last row partition does not match flat_values.'): 3081 sess = session.Session() 3082 with sess.as_default(): 3083 rt = ragged_factory_ops.constant([[3], [4, 5], [6]]) 3084 rt_shape = DynamicRaggedShape.from_tensor(rt) 3085 new_flat_values = constant_op.constant(['a', 'b', 'c']) 3086 rt2 = rt_shape._add_row_partitions(new_flat_values, validate=True) 3087 sess.run([rt2]) 3088 3089 3090class DynamicRaggedShapeSpecTest(parameterized.TestCase): 3091 3092 def assertRowPartitionSpecEqual(self, 3093 a: RowPartitionSpec, 3094 b: RowPartitionSpec, 3095 msg='') -> None: 3096 self.assertEqual(a.nrows, b.nrows, msg) 3097 self.assertEqual(a.nvals, b.nvals, msg) 3098 self.assertEqual(a.uniform_row_length, b.uniform_row_length, msg) 3099 self.assertEqual(a.dtype, b.dtype, msg) 3100 3101 def assertTensorShapeEqual(self, a: tensor_shape.TensorShape, 3102 b: tensor_shape.TensorShape) -> None: 3103 self.assertEqual(a, b) 3104 3105 def assertTensorSpecEqual(self, 3106 a: tensor_spec.TensorSpec, 3107 b: tensor_spec.TensorSpec) -> None: 3108 self.assertTensorShapeEqual(a.shape, b.shape) 3109 self.assertEqual(a.dtype, b.dtype) 3110 3111 def assertDynamicRaggedShapeSpecEqual(self, 3112 a: DynamicRaggedShape.Spec, 3113 b: DynamicRaggedShape.Spec) -> None: 3114 self.assertTensorShapeEqual(a._static_inner_shape, b._static_inner_shape) 3115 self.assertTensorSpecEqual(a._inner_shape, b._inner_shape) 3116 for i, (a, b) in enumerate(zip(a._row_partitions, b._row_partitions)): 3117 self.assertRowPartitionSpecEqual(a, b, 'Error in partition ' + str(i)) 3118 3119 @parameterized.parameters([ 3120 # Unknown dimension 3121 dict( 3122 shape=tensor_shape.TensorShape(None), 3123 num_row_partitions=1, 3124 dtype=dtypes.int32, 3125 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3126 row_partitions=[ 3127 RowPartitionSpec( 3128 nrows=None, 3129 nvals=None, 3130 uniform_row_length=None, 3131 dtype=dtypes.int32), 3132 RowPartitionSpec( 3133 nrows=None, 3134 nvals=None, 3135 uniform_row_length=None, 3136 dtype=dtypes.int32) 3137 ], 3138 static_inner_shape=tensor_shape.TensorShape(None), 3139 dtype=dtypes.int32)), 3140 # Unknown dimension, dense 3141 dict( 3142 shape=tensor_shape.TensorShape(None), 3143 num_row_partitions=0, 3144 dtype=dtypes.int32, 3145 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3146 row_partitions=[], 3147 static_inner_shape=tensor_shape.TensorShape(None), 3148 dtype=dtypes.int32)), 3149 # Scalar 3150 dict( 3151 shape=tensor_shape.TensorShape([]), 3152 num_row_partitions=0, 3153 dtype=dtypes.int32, 3154 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3155 row_partitions=[], 3156 static_inner_shape=tensor_shape.TensorShape([]), 3157 dtype=dtypes.int32)), 3158 # Vector 3159 dict( 3160 shape=tensor_shape.TensorShape([7]), 3161 num_row_partitions=0, 3162 dtype=dtypes.int32, 3163 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3164 row_partitions=[], 3165 static_inner_shape=tensor_shape.TensorShape([7]), 3166 dtype=dtypes.int32)), 3167 # Generic 3168 dict( 3169 shape=tensor_shape.TensorShape([5, 3, None, 4, 2, 5]), 3170 num_row_partitions=3, 3171 dtype=dtypes.int32, 3172 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3173 row_partitions=[ 3174 RowPartitionSpec( 3175 nrows=5, 3176 nvals=15, 3177 uniform_row_length=3, 3178 dtype=dtypes.int32), 3179 RowPartitionSpec( 3180 nrows=15, 3181 nvals=None, 3182 uniform_row_length=None, 3183 dtype=dtypes.int32), 3184 RowPartitionSpec( 3185 nrows=None, 3186 nvals=None, 3187 uniform_row_length=4, 3188 dtype=dtypes.int32) 3189 ], 3190 static_inner_shape=tensor_shape.TensorShape([None, 2, 5]), 3191 dtype=dtypes.int32)), 3192 # Generic, Dense 3193 dict( 3194 shape=tensor_shape.TensorShape([5, 3, None, 4, 2, 5]), 3195 num_row_partitions=0, 3196 dtype=dtypes.int32, 3197 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3198 row_partitions=[], 3199 static_inner_shape=tensor_shape.TensorShape( 3200 [5, 3, None, 4, 2, 5]), 3201 dtype=dtypes.int32)), 3202 ]) 3203 def test_from_tensor_shape(self, shape, num_row_partitions, dtype, expected): 3204 spec = DynamicRaggedShape.Spec._from_tensor_shape(shape, num_row_partitions, 3205 dtype) 3206 self.assertDynamicRaggedShapeSpecEqual(spec, expected) 3207 3208 @parameterized.parameters([ 3209 # Ridiculous DType. 3210 dict( 3211 shape=tensor_shape.TensorShape(None), 3212 num_row_partitions=1, 3213 dtype=dtypes.float32, 3214 error_type=ValueError, 3215 error_regex='dtype must be tf.int32 or tf.int64'), 3216 # num_row_partitions positive for scalar. 3217 dict( 3218 shape=tensor_shape.TensorShape([]), 3219 num_row_partitions=1, 3220 dtype=dtypes.int32, 3221 error_type=ValueError, 3222 error_regex='num_row_partitions should be zero ' + 3223 'if shape is a scalar or vector.'), 3224 dict( 3225 shape=tensor_shape.TensorShape([1, 2, 3]), 3226 num_row_partitions=3, 3227 dtype=dtypes.int32, 3228 error_type=ValueError, 3229 error_regex='num_row_partitions must be less than rank') 3230 ]) 3231 def test_from_tensor_shape_raises(self, shape, num_row_partitions, dtype, 3232 error_type, error_regex): 3233 with self.assertRaisesRegex(error_type, error_regex): 3234 DynamicRaggedShape.Spec._from_tensor_shape(shape, num_row_partitions, 3235 dtype) 3236 3237 def test_from_tensor_shape_raises_dtype(self): 3238 with self.assertRaisesRegex(ValueError, 3239 'dtype must be tf.int32 or tf.int64'): 3240 DynamicRaggedShape.Spec._from_tensor_shape( 3241 [], tensor_shape.TensorShape([1, 2, 3]), dtypes.float32) 3242 3243 def test_from_row_partition_inner_shape_and_dtype_raises_dtype(self): 3244 with self.assertRaisesRegex( 3245 ValueError, r'dtype of .* is .*int64.*: expected .*int32.*'): 3246 DynamicRaggedShape.Spec( 3247 row_partitions=[ 3248 RowPartitionSpec( 3249 nrows=None, 3250 nvals=None, 3251 uniform_row_length=None, 3252 dtype=dtypes.int32), 3253 RowPartitionSpec( 3254 nrows=None, 3255 nvals=None, 3256 uniform_row_length=None, 3257 dtype=dtypes.int64) 3258 ], 3259 static_inner_shape=tensor_shape.TensorShape(None), 3260 dtype=dtypes.int32) 3261 3262 def test_ranks(self): 3263 spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3264 shape=tensor_shape.TensorShape([5, None, 7, 4, 2, 5]), 3265 num_row_partitions=2, 3266 dtype=dtypes.int32) 3267 3268 self.assertEqual(spec.inner_rank, 4) 3269 self.assertEqual(spec.num_row_partitions, 2) 3270 self.assertEqual(spec.rank, 6) 3271 3272 def test_dimension_simple(self): 3273 spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3274 shape=tensor_shape.TensorShape([5, None, 7, 4, 2, 5]), 3275 num_row_partitions=2, 3276 dtype=dtypes.int32) 3277 3278 self.assertEqual(spec._dimension(0), 5) 3279 self.assertIsNone(spec._dimension(1)) 3280 self.assertEqual(spec._dimension(2), 7) 3281 self.assertEqual(spec._dimension(3), 4) 3282 self.assertEqual(spec._dimension(4), 2) 3283 self.assertEqual(spec._dimension(5), 5) 3284 3285 @parameterized.parameters([ 3286 dict( 3287 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3288 None, 0, dtypes.int32), 3289 dimension=0), 3290 dict( 3291 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3292 None, 0, dtypes.int32), 3293 dimension=1), 3294 ]) 3295 def test_dimension_none(self, spec, dimension): 3296 actual = spec._dimension(dimension) 3297 self.assertIsNone(actual) 3298 3299 @parameterized.parameters([ 3300 # Scalar. 3301 dict( 3302 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3303 [], 0, dtypes.int32), 3304 dimension=0, 3305 error_type=ValueError, 3306 error_regex='Index out of range: 0.'), 3307 # Scalar. 3308 dict( 3309 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3310 [], 0, dtypes.int32), 3311 dimension=1, 3312 error_type=ValueError, 3313 error_regex='Index out of range: 1.'), 3314 ]) 3315 def test_dimension_raises(self, spec, dimension, error_type, error_regex): 3316 with self.assertRaisesRegex(error_type, error_regex): 3317 spec._dimension(dimension) 3318 3319 def test_num_slices_in_dimension_ragged(self): 3320 spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3321 shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]), 3322 num_row_partitions=2, 3323 dtype=dtypes.int32) 3324 3325 self.assertEqual(spec._num_slices_in_dimension(0), 5) 3326 self.assertEqual(spec._num_slices_in_dimension(1), 5 * 3) 3327 self.assertEqual(spec._num_slices_in_dimension(2), 5 * 3 * 7) 3328 self.assertEqual(spec._num_slices_in_dimension(3), 5 * 3 * 7 * 4) 3329 self.assertIsNone(spec._num_slices_in_dimension(4)) 3330 self.assertIsNone(spec._num_slices_in_dimension(5)) 3331 self.assertIsNone(spec._num_slices_in_dimension(-2)) 3332 3333 def test_num_slices_in_dimension_ragged_alt(self): 3334 spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3335 shape=tensor_shape.TensorShape([5, 3, None, 2]), 3336 num_row_partitions=3, 3337 dtype=dtypes.int32) 3338 3339 self.assertEqual(spec._num_slices_in_dimension(0), 5) 3340 self.assertEqual(spec._num_slices_in_dimension(1), 5 * 3) 3341 self.assertIsNone(spec._num_slices_in_dimension(2)) 3342 self.assertIsNone(spec._num_slices_in_dimension(3)) 3343 3344 def test_num_slices_in_dimension_dense_known(self): 3345 spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3346 [5, 3, 4], 0, dtypes.int32) 3347 3348 self.assertEqual(spec._num_slices_in_dimension(0), 5) 3349 self.assertEqual(spec._num_slices_in_dimension(1), 15) 3350 self.assertEqual(spec._num_slices_in_dimension(2), 60) 3351 3352 @parameterized.parameters([ 3353 dict( 3354 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3355 None, 0, dtypes.int32), 3356 dimension='CRAZY', 3357 error_type=TypeError, 3358 error_regex='axis must be an integer'), 3359 dict( 3360 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3361 None, 0, dtypes.int32), 3362 dimension=-1, 3363 error_type=ValueError, 3364 error_regex='axis=-1 may only be negative' + 3365 ' if rank is statically known.') 3366 ]) 3367 def test_num_slices_in_dimension_raises(self, spec, dimension, error_type, 3368 error_regex): 3369 with self.assertRaisesRegex(error_type, error_regex): 3370 spec._num_slices_in_dimension(dimension) 3371 3372 def test_with_dtype(self): 3373 spec = DynamicRaggedShape.Spec._from_tensor_shape( 3374 shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]), 3375 num_row_partitions=2, 3376 dtype=dtypes.int32) 3377 actual = spec.with_dtype(dtypes.int64) 3378 self.assertEqual(actual.dtype, dtypes.int64) 3379 self.assertEqual(actual._row_partitions[0].dtype, dtypes.int64) 3380 self.assertEqual(actual._row_partitions[1].dtype, dtypes.int64) 3381 3382 @parameterized.parameters([ 3383 dict( 3384 original=DynamicRaggedShape.Spec._from_tensor_shape( 3385 shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]), 3386 num_row_partitions=2, 3387 dtype=dtypes.int32), 3388 num_row_partitions=3, 3389 expected=DynamicRaggedShape.Spec._from_tensor_shape( 3390 shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]), 3391 num_row_partitions=3, 3392 dtype=dtypes.int32)), 3393 dict( 3394 original=DynamicRaggedShape.Spec._from_tensor_shape( 3395 shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]), 3396 num_row_partitions=2, 3397 dtype=dtypes.int32), 3398 num_row_partitions=1, 3399 expected=DynamicRaggedShape.Spec._from_tensor_shape( 3400 shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]), 3401 num_row_partitions=1, 3402 dtype=dtypes.int32)), 3403 ]) 3404 def test_with_num_row_partitions(self, original, num_row_partitions, 3405 expected): 3406 actual = original._with_num_row_partitions(num_row_partitions) 3407 self.assertDynamicRaggedShapeSpecEqual(actual, expected) 3408 3409 @parameterized.parameters([ 3410 dict( 3411 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3412 None, 0, dtypes.int32), 3413 num_row_partitions=2, 3414 error_type=ValueError, 3415 error_regex='Changing num_row_partitions with unknown rank'), 3416 dict( 3417 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3418 [1, 2, 3, 4], 0, dtypes.int32), 3419 num_row_partitions=4, 3420 error_type=ValueError, 3421 error_regex='Number of row partitions too large'), 3422 dict( 3423 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3424 [1, 2, 3, 4], 0, dtypes.int32), 3425 num_row_partitions=-3, 3426 error_type=ValueError, 3427 error_regex='Number of row partitions negative'), 3428 ]) 3429 def test_with_num_row_partitions_raises(self, spec, num_row_partitions, 3430 error_type, error_regex): 3431 with self.assertRaisesRegex(error_type, error_regex): 3432 spec._with_num_row_partitions(num_row_partitions) 3433 3434 def test_truncate(self): 3435 spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape( 3436 shape=tensor_shape.TensorShape([5, 3, 7, 4, None, 5]), 3437 num_row_partitions=2, 3438 dtype=dtypes.int32) 3439 3440 for new_rank in range(7): 3441 truncation = spec._truncate(new_rank) 3442 self.assertEqual(truncation.rank, new_rank) 3443 for i in range(new_rank): 3444 self.assertEqual( 3445 truncation._dimension(i), spec._dimension(i), 3446 'Mismatch on new_rank ' + str(new_rank) + ' on dimension ' + str(i)) 3447 3448 def test_truncate_unknown(self): 3449 spec = DynamicRaggedShape.Spec( 3450 row_partitions=[ 3451 RowPartitionSpec( 3452 nrows=3, nvals=7, uniform_row_length=None, dtype=dtypes.int32), 3453 RowPartitionSpec( 3454 nrows=7, 3455 nvals=None, 3456 uniform_row_length=None, 3457 dtype=dtypes.int32) 3458 ], 3459 static_inner_shape=tensor_shape.TensorShape(None), 3460 dtype=dtypes.int32) 3461 expected = DynamicRaggedShape.Spec( 3462 row_partitions=[ 3463 RowPartitionSpec( 3464 nrows=3, nvals=7, uniform_row_length=None, dtype=dtypes.int32), 3465 RowPartitionSpec( 3466 nrows=7, 3467 nvals=None, 3468 uniform_row_length=None, 3469 dtype=dtypes.int32) 3470 ], 3471 static_inner_shape=tensor_shape.TensorShape([None, None]), 3472 dtype=dtypes.int32) 3473 actual = spec._truncate(4) 3474 self.assertDynamicRaggedShapeSpecEqual(actual, expected) 3475 3476 @parameterized.parameters([ 3477 # Standard scalar 3478 dict( 3479 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3480 row_partitions=[], 3481 static_inner_shape=tensor_shape.TensorShape([]), 3482 dtype=dtypes.int32), 3483 expected=0), 3484 dict( 3485 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3486 row_partitions=[ 3487 RowPartitionSpec( 3488 nrows=None, 3489 nvals=None, 3490 uniform_row_length=None, 3491 dtype=dtypes.int64) 3492 ], 3493 static_inner_shape=tensor_shape.TensorShape([None]), 3494 dtype=dtypes.int64), 3495 expected=1), 3496 # Not knowing the shape of the inner shape is weird. 3497 dict( 3498 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3499 row_partitions=[ 3500 RowPartitionSpec( 3501 nrows=None, 3502 nvals=None, 3503 uniform_row_length=None, 3504 dtype=dtypes.int64) 3505 ], 3506 static_inner_shape=tensor_shape.TensorShape(None), 3507 dtype=dtypes.int64), 3508 expected=None), 3509 ]) 3510 def test_inner_rank(self, spec, expected): 3511 actual = spec.inner_rank 3512 self.assertEqual(expected, actual) 3513 3514 @parameterized.parameters([ 3515 # Standard scalar 3516 dict( 3517 other_spec=tensor_spec.TensorSpec([], dtypes.float32), 3518 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3519 row_partitions=[], 3520 static_inner_shape=tensor_shape.TensorShape([]), 3521 dtype=dtypes.int64)), 3522 dict( 3523 other_spec=ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32), 3524 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3525 row_partitions=[ 3526 RowPartitionSpec(nrows=None, 3527 nvals=None, 3528 uniform_row_length=None, 3529 dtype=dtypes.int64) 3530 ], 3531 static_inner_shape=tensor_shape.TensorShape([None]), 3532 dtype=dtypes.int64)), 3533 dict( 3534 other_spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3535 row_partitions=[ 3536 RowPartitionSpec(nrows=None, 3537 nvals=None, 3538 uniform_row_length=None, 3539 dtype=dtypes.int64) 3540 ], 3541 static_inner_shape=tensor_shape.TensorShape([None]), 3542 dtype=dtypes.int64), 3543 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3544 row_partitions=[ 3545 RowPartitionSpec(nrows=None, 3546 nvals=None, 3547 uniform_row_length=None, 3548 dtype=dtypes.int64) 3549 ], 3550 static_inner_shape=tensor_shape.TensorShape([None]), 3551 dtype=dtypes.int64)), 3552 ]) 3553 def test_from_spec(self, other_spec, expected): 3554 actual = DynamicRaggedShape.Spec._from_spec(other_spec) 3555 self.assertDynamicRaggedShapeSpecEqual(expected, actual) 3556 3557 @parameterized.parameters([ 3558 dict( 3559 row_partitions=[ 3560 RowPartitionSpec( 3561 nrows=None, 3562 nvals=None, 3563 uniform_row_length=None, 3564 dtype=dtypes.int64) 3565 ], 3566 static_inner_shape=tensor_shape.TensorShape([None]), 3567 inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3568 dict( 3569 row_partitions=[ 3570 RowPartitionSpec( 3571 nrows=None, 3572 nvals=None, 3573 uniform_row_length=None, 3574 dtype=dtypes.int64) 3575 ], 3576 static_inner_shape=tensor_shape.TensorShape([None, 3]), 3577 inner_shape=tensor_spec.TensorSpec([2], dtypes.int64)), 3578 dict( 3579 row_partitions=[ 3580 RowPartitionSpec( 3581 nrows=6, 3582 nvals=None, 3583 uniform_row_length=None, 3584 dtype=dtypes.int64) 3585 ], 3586 static_inner_shape=tensor_shape.TensorShape([None]), 3587 inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3588 dict( 3589 row_partitions=[ 3590 RowPartitionSpec( 3591 nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64) 3592 ], 3593 static_inner_shape=tensor_shape.TensorShape([60]), 3594 inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3595 dict( 3596 row_partitions=[ 3597 RowPartitionSpec( 3598 nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64), 3599 RowPartitionSpec( 3600 nrows=60, 3601 nvals=120, 3602 uniform_row_length=None, 3603 dtype=dtypes.int64) 3604 ], 3605 static_inner_shape=tensor_shape.TensorShape([120]), 3606 inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3607 dict( 3608 row_partitions=[ 3609 RowPartitionSpec( 3610 nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64) 3611 ], 3612 static_inner_shape=tensor_shape.TensorShape(None), 3613 inner_shape=tensor_spec.TensorSpec([None], dtypes.int64)) 3614 ]) 3615 def test_constructor_idempotent(self, row_partitions, static_inner_shape, 3616 inner_shape): 3617 # The constructor detects if there is any additional information that 3618 # can be inferred from what is given. 3619 original = dynamic_ragged_shape.DynamicRaggedShape.Spec( 3620 row_partitions, static_inner_shape, inner_shape.dtype) 3621 self.assertTensorShapeEqual(original._static_inner_shape, 3622 static_inner_shape) 3623 self.assertTensorSpecEqual(original._inner_shape, inner_shape) 3624 for i, (a, b) in enumerate(zip(original._row_partitions, row_partitions)): 3625 self.assertRowPartitionSpecEqual(a, b, 'Error in partition ' + str(i)) 3626 3627 @parameterized.parameters([ 3628 dict( 3629 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3630 row_partitions=[ 3631 RowPartitionSpec( 3632 nrows=3, 3633 nvals=None, 3634 uniform_row_length=4, 3635 dtype=dtypes.int64) 3636 ], 3637 static_inner_shape=tensor_shape.TensorShape([None]), 3638 dtype=dtypes.int64), 3639 expected_row_partitions=[ 3640 RowPartitionSpec( 3641 nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64) 3642 ], 3643 expected_static_inner_shape=tensor_shape.TensorShape([12]), 3644 expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3645 dict( 3646 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3647 row_partitions=[ 3648 RowPartitionSpec( 3649 nrows=None, 3650 nvals=None, 3651 uniform_row_length=3, 3652 dtype=dtypes.int64) 3653 ], 3654 static_inner_shape=tensor_shape.TensorShape([30]), 3655 dtype=dtypes.int64), 3656 expected_row_partitions=[ 3657 RowPartitionSpec( 3658 nrows=10, nvals=30, uniform_row_length=3, dtype=dtypes.int64) 3659 ], 3660 expected_static_inner_shape=tensor_shape.TensorShape([30]), 3661 expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3662 dict( 3663 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3664 row_partitions=[ 3665 RowPartitionSpec( 3666 nrows=6, 3667 nvals=None, 3668 uniform_row_length=10, 3669 dtype=dtypes.int64) 3670 ], 3671 static_inner_shape=tensor_shape.TensorShape([None]), 3672 dtype=dtypes.int64), 3673 expected_row_partitions=[ 3674 RowPartitionSpec( 3675 nrows=6, nvals=60, uniform_row_length=10, dtype=dtypes.int64) 3676 ], 3677 expected_static_inner_shape=tensor_shape.TensorShape([60]), 3678 expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3679 dict( 3680 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3681 row_partitions=[ 3682 RowPartitionSpec( 3683 nrows=6, 3684 nvals=None, 3685 uniform_row_length=None, 3686 dtype=dtypes.int64), 3687 RowPartitionSpec( 3688 nrows=60, 3689 nvals=None, 3690 uniform_row_length=None, 3691 dtype=dtypes.int64) 3692 ], 3693 static_inner_shape=tensor_shape.TensorShape([120]), 3694 dtype=dtypes.int64), 3695 expected_row_partitions=[ 3696 RowPartitionSpec( 3697 nrows=6, 3698 nvals=60, 3699 uniform_row_length=None, 3700 dtype=dtypes.int64), 3701 RowPartitionSpec( 3702 nrows=60, 3703 nvals=120, 3704 uniform_row_length=None, 3705 dtype=dtypes.int64) 3706 ], 3707 expected_static_inner_shape=tensor_shape.TensorShape([120]), 3708 expected_inner_shape=tensor_spec.TensorSpec([1], dtypes.int64)), 3709 ]) 3710 def test_constructor_improvements(self, original, expected_row_partitions, 3711 expected_static_inner_shape, 3712 expected_inner_shape): 3713 # Note that self_merge is only idempotent if no data is partially present. 3714 self.assertTensorShapeEqual(original._static_inner_shape, 3715 expected_static_inner_shape) 3716 self.assertTensorSpecEqual(original._inner_shape, expected_inner_shape) 3717 for i, (a, b) in enumerate( 3718 zip(original._row_partitions, expected_row_partitions)): 3719 self.assertRowPartitionSpecEqual(a, b, 'Error in partition ' + str(i)) 3720 3721 @parameterized.parameters([ 3722 dict( 3723 row_partitions=[ 3724 RowPartitionSpec( 3725 nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64) 3726 ], 3727 static_inner_shape=tensor_shape.TensorShape([]), 3728 dtype=dtypes.int64, 3729 error_type=ValueError, 3730 msg='If row_partitions are provided, must have inner_rank > 0'), 3731 dict( 3732 row_partitions=RowPartitionSpec( 3733 nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64), 3734 static_inner_shape=tensor_shape.TensorShape([]), 3735 dtype=dtypes.int64, 3736 error_type=TypeError, 3737 msg='row_partitions should be an Iterable'), 3738 dict( 3739 row_partitions=[1, 2, 3], 3740 static_inner_shape=tensor_shape.TensorShape([12]), 3741 dtype=dtypes.int64, 3742 error_type=TypeError, 3743 msg='row_partitions should be an Iterable of RowPartitionSpecs'), 3744 dict( 3745 row_partitions=[ 3746 RowPartitionSpec( 3747 nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64) 3748 ], 3749 static_inner_shape=3, 3750 dtype=dtypes.int64, 3751 error_type=ValueError, 3752 msg='Dimensions 12 and 3'), 3753 dict( 3754 row_partitions=[ 3755 RowPartitionSpec( 3756 nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64) 3757 ], 3758 static_inner_shape=tensor_shape.TensorShape([2]), 3759 dtype=456, 3760 error_type=TypeError, 3761 msg='Cannot convert'), 3762 dict( 3763 row_partitions=[ 3764 RowPartitionSpec( 3765 nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64) 3766 ], 3767 static_inner_shape=tensor_shape.TensorShape([12]), 3768 dtype=dtypes.int32, 3769 error_type=ValueError, 3770 msg='dtype of RowPartitionSpec'), 3771 dict( 3772 row_partitions=[ 3773 RowPartitionSpec( 3774 nrows=3, nvals=12, uniform_row_length=4, dtype=dtypes.int64) 3775 ], 3776 static_inner_shape=tensor_shape.TensorShape([11]), 3777 dtype=dtypes.int64, 3778 error_type=ValueError, 3779 msg='Dimensions 12 and 11 are not compatible'), 3780 dict( 3781 row_partitions=[ 3782 RowPartitionSpec(nvals=3, dtype=dtypes.int64), 3783 RowPartitionSpec(uniform_row_length=4, dtype=dtypes.int64), 3784 RowPartitionSpec(nrows=17, dtype=dtypes.int64), 3785 ], 3786 static_inner_shape=tensor_shape.TensorShape([20]), 3787 dtype=dtypes.int64, 3788 error_type=ValueError, 3789 msg='Dimensions 17 and 12 are not compatible'), 3790 ]) 3791 def test_constructor_raises(self, row_partitions, static_inner_shape, 3792 dtype, error_type, msg): 3793 # Note that self_merge is only idempotent if no data is partially present. 3794 with self.assertRaisesRegex(error_type, msg): 3795 dynamic_ragged_shape.DynamicRaggedShape.Spec( 3796 row_partitions=row_partitions, 3797 static_inner_shape=static_inner_shape, 3798 dtype=dtype) 3799 3800 @parameterized.parameters([ 3801 # Unknown rank 3802 dict( 3803 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3804 row_partitions=[], 3805 static_inner_shape=tensor_shape.TensorShape(None), 3806 dtype=dtypes.int64), 3807 expected=tensor_shape.TensorShape(None)), 3808 # Scalar 3809 dict( 3810 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3811 row_partitions=[], 3812 static_inner_shape=tensor_shape.TensorShape([]), 3813 dtype=dtypes.int64), 3814 expected=tensor_shape.TensorShape([])), 3815 # Vector 3816 dict( 3817 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3818 row_partitions=[], 3819 static_inner_shape=tensor_shape.TensorShape([3]), 3820 dtype=dtypes.int64), 3821 expected=tensor_shape.TensorShape([3])), 3822 # Dense 3823 dict( 3824 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3825 row_partitions=[], 3826 static_inner_shape=tensor_shape.TensorShape([3, 2, None]), 3827 dtype=dtypes.int64), 3828 expected=tensor_shape.TensorShape([3, 2, None])), 3829 # Ragged 3830 dict( 3831 original=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3832 row_partitions=[ 3833 RowPartitionSpec(nrows=6, 3834 nvals=None, 3835 uniform_row_length=10, 3836 dtype=dtypes.int64), 3837 RowPartitionSpec(nrows=60, 3838 nvals=None, 3839 uniform_row_length=None, 3840 dtype=dtypes.int64) 3841 ], 3842 static_inner_shape=tensor_shape.TensorShape([120]), 3843 dtype=dtypes.int64), 3844 expected=tensor_shape.TensorShape([6, 10, None])), 3845 3846 ]) 3847 def test_to_tensor_shape(self, original, expected): 3848 # Note that self_merge is only idempotent if no data is partially present. 3849 actual = original._to_tensor_shape() 3850 self.assertEqual(actual, expected) 3851 3852 @parameterized.parameters([ 3853 dict( 3854 a=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3855 row_partitions=[], 3856 static_inner_shape=tensor_shape.TensorShape([]), 3857 dtype=dtypes.int32), 3858 b=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3859 row_partitions=[], 3860 static_inner_shape=tensor_shape.TensorShape([]), 3861 dtype=dtypes.int32), 3862 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3863 row_partitions=[], 3864 static_inner_shape=tensor_shape.TensorShape([]), 3865 dtype=dtypes.int32)), 3866 dict( 3867 a=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3868 row_partitions=[], 3869 static_inner_shape=tensor_shape.TensorShape([3, None]), 3870 dtype=dtypes.int32), 3871 b=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3872 row_partitions=[], 3873 static_inner_shape=tensor_shape.TensorShape([None, 4]), 3874 dtype=dtypes.int32), 3875 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3876 row_partitions=[], 3877 static_inner_shape=tensor_shape.TensorShape([3, 4]), 3878 dtype=dtypes.int32)), 3879 dict( 3880 a=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3881 row_partitions=[ 3882 RowPartitionSpec( 3883 nrows=6, 3884 nvals=None, 3885 uniform_row_length=None, 3886 dtype=dtypes.int64) 3887 ], 3888 static_inner_shape=tensor_shape.TensorShape([None]), 3889 dtype=dtypes.int64), 3890 b=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3891 row_partitions=[ 3892 RowPartitionSpec( 3893 nrows=6, 3894 nvals=None, 3895 uniform_row_length=10, 3896 dtype=dtypes.int64) 3897 ], 3898 static_inner_shape=tensor_shape.TensorShape([None]), 3899 dtype=dtypes.int64), 3900 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3901 row_partitions=[ 3902 RowPartitionSpec( 3903 nrows=6, 3904 nvals=60, 3905 uniform_row_length=10, 3906 dtype=dtypes.int64) 3907 ], 3908 static_inner_shape=tensor_shape.TensorShape([60]), 3909 dtype=dtypes.int64)), 3910 dict( 3911 a=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3912 row_partitions=[ 3913 RowPartitionSpec( 3914 nrows=6, 3915 nvals=None, 3916 uniform_row_length=None, 3917 dtype=dtypes.int64) 3918 ], 3919 static_inner_shape=tensor_shape.TensorShape([None]), 3920 dtype=dtypes.int64), 3921 b=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3922 row_partitions=[], 3923 static_inner_shape=tensor_shape.TensorShape([None, 10]), 3924 dtype=dtypes.int64), 3925 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3926 row_partitions=[ 3927 RowPartitionSpec( 3928 nrows=6, 3929 nvals=60, 3930 uniform_row_length=10, 3931 dtype=dtypes.int64) 3932 ], 3933 static_inner_shape=tensor_shape.TensorShape([60]), 3934 dtype=dtypes.int64)) 3935 ]) 3936 def test_merge_with(self, 3937 a: DynamicRaggedShape.Spec, 3938 b: DynamicRaggedShape.Spec, 3939 expected: DynamicRaggedShape.Spec): 3940 actual = a._merge_with(b) 3941 actual_rev = b._merge_with(a) 3942 3943 self.assertDynamicRaggedShapeSpecEqual(actual, expected) 3944 self.assertDynamicRaggedShapeSpecEqual(actual_rev, expected) 3945 3946 @parameterized.parameters([ 3947 dict( 3948 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3949 row_partitions=[ 3950 RowPartitionSpec( 3951 nrows=6, 3952 nvals=3, 3953 uniform_row_length=None, 3954 dtype=dtypes.int64) 3955 ], 3956 static_inner_shape=tensor_shape.TensorShape([3]), 3957 dtype=dtypes.int64), 3958 batch_size=3, 3959 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3960 row_partitions=[ 3961 RowPartitionSpec( 3962 nrows=3, 3963 nvals=18, 3964 uniform_row_length=6, 3965 dtype=dtypes.int64), 3966 RowPartitionSpec( 3967 nrows=18, 3968 nvals=9, 3969 uniform_row_length=None, 3970 dtype=dtypes.int64) 3971 ], 3972 static_inner_shape=tensor_shape.TensorShape([9]), 3973 dtype=dtypes.int64)), 3974 dict( 3975 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3976 row_partitions=[ 3977 RowPartitionSpec( 3978 nrows=None, 3979 nvals=3, 3980 uniform_row_length=None, 3981 dtype=dtypes.int64) 3982 ], 3983 static_inner_shape=tensor_shape.TensorShape([3]), 3984 dtype=dtypes.int64), 3985 batch_size=3, 3986 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 3987 row_partitions=[ 3988 RowPartitionSpec( 3989 nrows=3, 3990 nvals=None, 3991 uniform_row_length=None, 3992 dtype=dtypes.int64), 3993 RowPartitionSpec( 3994 nrows=None, 3995 nvals=9, 3996 uniform_row_length=None, 3997 dtype=dtypes.int64) 3998 ], 3999 static_inner_shape=tensor_shape.TensorShape([9]), 4000 dtype=dtypes.int64)), 4001 dict( 4002 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4003 row_partitions=[ 4004 RowPartitionSpec( 4005 nrows=None, 4006 nvals=None, 4007 uniform_row_length=None, 4008 dtype=dtypes.int64) 4009 ], 4010 static_inner_shape=tensor_shape.TensorShape([None]), 4011 dtype=dtypes.int64), 4012 batch_size=3, 4013 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4014 row_partitions=[ 4015 RowPartitionSpec( 4016 nrows=3, 4017 nvals=None, 4018 uniform_row_length=None, 4019 dtype=dtypes.int64), 4020 RowPartitionSpec( 4021 nrows=None, 4022 nvals=None, 4023 uniform_row_length=None, 4024 dtype=dtypes.int64) 4025 ], 4026 static_inner_shape=tensor_shape.TensorShape([None]), 4027 dtype=dtypes.int64)), 4028 dict( 4029 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4030 row_partitions=[ 4031 RowPartitionSpec( 4032 nrows=None, 4033 nvals=None, 4034 uniform_row_length=None, 4035 dtype=dtypes.int64) 4036 ], 4037 static_inner_shape=tensor_shape.TensorShape(None), 4038 dtype=dtypes.int64), 4039 batch_size=3, 4040 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4041 row_partitions=[ 4042 RowPartitionSpec( 4043 nrows=3, 4044 nvals=None, 4045 uniform_row_length=None, 4046 dtype=dtypes.int64), 4047 RowPartitionSpec( 4048 nrows=None, 4049 nvals=None, 4050 uniform_row_length=None, 4051 dtype=dtypes.int64) 4052 ], 4053 static_inner_shape=tensor_shape.TensorShape(None), 4054 dtype=dtypes.int64)), 4055 dict( 4056 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4057 row_partitions=[ 4058 RowPartitionSpec( 4059 nrows=None, 4060 nvals=6, 4061 uniform_row_length=None, 4062 dtype=dtypes.int64) 4063 ], 4064 static_inner_shape=tensor_shape.TensorShape([6, 4]), 4065 dtype=dtypes.int64), 4066 batch_size=3, 4067 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4068 row_partitions=[ 4069 RowPartitionSpec( 4070 nrows=3, 4071 nvals=None, 4072 uniform_row_length=None, 4073 dtype=dtypes.int64), 4074 RowPartitionSpec( 4075 nrows=None, 4076 nvals=18, 4077 uniform_row_length=None, 4078 dtype=dtypes.int64) 4079 ], 4080 static_inner_shape=tensor_shape.TensorShape([18, 4]), 4081 dtype=dtypes.int64)), 4082 dict( 4083 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4084 row_partitions=[], 4085 static_inner_shape=tensor_shape.TensorShape(None), 4086 dtype=dtypes.int32), 4087 batch_size=3, 4088 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4089 row_partitions=[], 4090 static_inner_shape=tensor_shape.TensorShape(None), 4091 dtype=dtypes.int32)), 4092 dict( 4093 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4094 row_partitions=[], 4095 static_inner_shape=tensor_shape.TensorShape([8, 9]), 4096 dtype=dtypes.int32), 4097 batch_size=7, 4098 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4099 row_partitions=[], 4100 static_inner_shape=tensor_shape.TensorShape([7, 8, 9]), 4101 dtype=dtypes.int32)), 4102 ]) 4103 def test_batch(self, 4104 spec: DynamicRaggedShape.Spec, 4105 batch_size: int, 4106 expected: DynamicRaggedShape.Spec): 4107 encoder = dynamic_ragged_shape._DynamicRaggedShapeBatchEncoder() 4108 actual = encoder.batch(spec, batch_size) 4109 self.assertDynamicRaggedShapeSpecEqual(actual, expected) 4110 4111 @parameterized.parameters([ 4112 dict( 4113 spec=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4114 row_partitions=[ 4115 RowPartitionSpec( 4116 nrows=6, 4117 nvals=3, 4118 uniform_row_length=None, 4119 dtype=dtypes.int32)], 4120 static_inner_shape=tensor_shape.TensorShape([3]), 4121 dtype=dtypes.int32), 4122 expected=dynamic_ragged_shape.DynamicRaggedShape.Spec( 4123 row_partitions=[], 4124 static_inner_shape=tensor_shape.TensorShape([None]), 4125 dtype=dtypes.int32)) 4126 ]) 4127 def test_unbatch(self, spec: DynamicRaggedShape.Spec, 4128 expected: DynamicRaggedShape.Spec): 4129 encoder = dynamic_ragged_shape._DynamicRaggedShapeBatchEncoder() 4130 actual = encoder.unbatch(spec) 4131 self.assertDynamicRaggedShapeSpecEqual(actual, expected) 4132 4133 def test_repr(self): 4134 original = dynamic_ragged_shape.DynamicRaggedShape.Spec( 4135 row_partitions=[ 4136 RowPartitionSpec( 4137 nrows=6, 4138 nvals=None, 4139 uniform_row_length=None, 4140 dtype=dtypes.int64) 4141 ], 4142 static_inner_shape=tensor_shape.TensorShape([None]), 4143 dtype=dtypes.int64) 4144 representation = repr(original) 4145 static_inner_shape = tensor_shape.TensorShape([None]) 4146 expected = ('DynamicRaggedShape.Spec(' + 4147 'row_partitions=(RowPartitionSpec(' + 4148 'nrows=6, nvals=None, uniform_row_length=None, ' + 4149 'dtype=tf.int64),), ' + 4150 f'static_inner_shape={static_inner_shape!r}, ' + 4151 'dtype=tf.int64)') 4152 self.assertEqual(representation, expected) 4153 4154 @parameterized.parameters([ 4155 dict( 4156 lengths=[3, 4, 5], 4157 expected=DynamicRaggedShape.Spec( 4158 row_partitions=[], 4159 static_inner_shape=tensor_shape.TensorShape([3, 4, 5]), 4160 dtype=dtypes.int64)), 4161 dict( 4162 lengths=[2, (4, 1), 5], 4163 expected=DynamicRaggedShape.Spec( 4164 row_partitions=[RowPartitionSpec(nrows=2, nvals=5)], 4165 static_inner_shape=tensor_shape.TensorShape([5, 5]), 4166 dtype=dtypes.int64)), 4167 dict( 4168 lengths=[2, (4, 1), 5], 4169 dtype=dtypes.int32, 4170 expected=DynamicRaggedShape.Spec( 4171 row_partitions=[ 4172 RowPartitionSpec(nrows=2, nvals=5, dtype=dtypes.int32)], 4173 static_inner_shape=tensor_shape.TensorShape([5, 5]), 4174 dtype=dtypes.int32)), 4175 ]) 4176 def test_from_value(self, lengths, expected, dtype=None): 4177 original = DynamicRaggedShape.from_lengths(lengths) 4178 if dtype is not None: 4179 original = original.with_dtype(dtype) 4180 actual = dynamic_ragged_shape.DynamicRaggedShape.Spec.from_value(original) 4181 self.assertTensorShapeEqual(actual, expected) 4182 4183if __name__ == '__main__': 4184 googletest.main() 4185