1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Shapes & broadcasting for RaggedTensors.""" 16 17from tensorflow.python.framework import constant_op 18from tensorflow.python.framework import dtypes 19from tensorflow.python.framework import ops 20from tensorflow.python.framework import tensor_shape 21from tensorflow.python.framework import tensor_util 22from tensorflow.python.ops import array_ops 23from tensorflow.python.ops import control_flow_ops 24from tensorflow.python.ops import math_ops 25from tensorflow.python.ops.ragged import ragged_array_ops 26from tensorflow.python.ops.ragged import ragged_config 27from tensorflow.python.ops.ragged import ragged_tensor 28from tensorflow.python.ops.ragged import ragged_util 29 30 31class RaggedTensorDynamicShape: 32 """A collection of tensors encoding the shape of a potentially ragged tensor. 33 34 Each `RaggedTensorDynamicShape` consists of an ordered list of dimension 35 sizes. There are two dimension types: 36 37 * "Uniform dimensions" are dimensions where all slices have the same 38 length. `RaggedTensorDynamicShape` records the size of each uniform 39 dimension using a single scalar integer. 40 41 * "Ragged dimensions" are dimensions whose slices may have different 42 lengths. `RaggedTensorDynamicShape` records the size of each ragged 43 dimension using an integer vector containing the slice lengths for all 44 the slices across that dimension. 45 46 Furthermore, there are two ways a dimension might be encoded: 47 48 * "Partitioned dimensions" are dimensions that are encoded using a 49 `RaggedTensor`'s `nested_row_splits`. The outermostmost partitioned 50 dimension must be uniform, and the innermost partitioned dimension must 51 be ragged. 52 53 * "Inner dimensions" are dimensions that are encoded using a 54 `RaggedTensor`'s `flat_values`. Inner dimensions are always uniform. 55 56 The sizes of partitioned dimensions are recorded using `partitioned_dim_sizes` 57 and `inner_dim_sizes`: 58 59 * `partitioned_dim_sizes` is a list of tensors (one for each partitioned 60 dimension). 61 62 * For uniform dimensions, the tensor is an integer scalar specifying the 63 size of all slices across that dimension. 64 * For ragged dimensions, the tensor is an integer vector specifying the 65 size of each slice across that dimension. 66 67 * `inner_dim_sizes` is a single integer vector, where each element 68 specifies the size of a single inner dimension. 69 70 Examples: 71 72 Tensor | Ragged | Partitioned Dim Sizes | Inner Dim 73 : Rank : : Sizes 74 ------------------------------ | ------ | ---------------------- | ---------- 75 `[[1, 2, 3], [4, 5, 6]]` | 0 | | `2, 3` 76 `[[1, 2], [], [3, 4, 5]]` | 1 | `3, (2, 0, 3)` | 77 `[[[1, 2], [3, 4]], [[5, 6]]]` | 1 | `2, (2, 1)` | 2 78 `[[[1, 2], [3]], [[4, 5]]]` | 2 | `2, (2, 1), (2, 1, 2)` | 79 """ 80 81 def __init__(self, partitioned_dim_sizes, inner_dim_sizes, 82 dim_size_dtype=None): 83 """Creates a RaggedTensorDynamicShape. 84 85 Args: 86 partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for 87 each partitioned dimension. If dimension `d` is uniform, then 88 `partitioned_dim_sizes[d]` must be an integer scalar, specifying the 89 size of all slices across dimension `d`. If dimension `d` is ragged, 90 then `partitioned_dim_sizes[d]` must be an integer vector, specifying 91 the size of each slice across dimension `d`. 92 inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the 93 number of inner dimensions. `inner_dim_sizes[n]` is the size of all 94 slices across the `n`th inner dimension (which is the 95 `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor. 96 dim_size_dtype: dtype for dimension sizes. If not specified, then it 97 is chosen based on the dtypes of `partitioned_dim_sizes` and 98 `inner_dim_sizes`. 99 """ 100 assert isinstance(partitioned_dim_sizes, (list, tuple)) 101 102 with ops.name_scope(None, 'RaggedTensorDynamicShape', 103 (partitioned_dim_sizes, inner_dim_sizes)): 104 partitioned_dim_sizes = tuple( 105 ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i) 106 for (i, size) in enumerate(partitioned_dim_sizes)) 107 inner_dim_sizes = ops.convert_to_tensor( 108 inner_dim_sizes, name='inner_dim_sizes') 109 110 # Validate shapes. 111 if partitioned_dim_sizes: 112 for axis, dimension_size in enumerate(partitioned_dim_sizes): 113 if dimension_size.shape.ndims is None: 114 raise ValueError( 115 'rank of partitioned_dim_sizes[%d] is unknown' % axis) 116 dimension_size.shape.with_rank_at_most(1) 117 if partitioned_dim_sizes[0].shape.ndims == 1: 118 raise ValueError('outermost partitioned dimension must be uniform') 119 if partitioned_dim_sizes[-1].shape.ndims == 0: 120 raise ValueError('innermost partitioned dimension must be ragged') 121 inner_dim_sizes.shape.assert_has_rank(1) 122 123 # Convert dimension size tensors to a single dtype. 124 if dim_size_dtype is None: 125 dim_size_dtypes = set( 126 p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1) 127 if not dim_size_dtypes: 128 dim_size_dtype = dtypes.int64 129 elif len(dim_size_dtypes) == 1: 130 dim_size_dtype = dim_size_dtypes.pop() 131 else: 132 if not ragged_config.auto_cast_partition_dtype(): 133 raise ValueError('partitioned_dim_sizes must have matching dtypes') 134 dim_size_dtype = dtypes.int64 135 partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype) 136 for p in partitioned_dim_sizes) 137 inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype) 138 139 self._partitioned_dim_sizes = partitioned_dim_sizes 140 self._inner_dim_sizes = inner_dim_sizes 141 142 def __repr__(self): 143 return ('RaggedTensorDynamicShape' 144 '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' % 145 (self._partitioned_dim_sizes, self._inner_dim_sizes)) 146 147 @staticmethod 148 def from_dim_sizes(dim_sizes): 149 """Constructs a ragged shape from a list of dimension sizes. 150 151 This list contains a single tensor for each dimension, where the tensor 152 is a scalar if the dimension is uniform, or a vector if the dimension is 153 ragged. 154 155 Args: 156 dim_sizes: List of int32 or int64 scalars or vectors. 157 158 Returns: 159 A RaggedTensorDynamicShape. 160 """ 161 with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes', 162 [dim_sizes]): 163 dim_sizes = tuple( 164 ops.convert_to_tensor(size, preferred_dtype=dtypes.int64, 165 name='dim_sizes') for size in dim_sizes) 166 # Split the dimensions into partitioned & inner dimensions. 167 inner_split = 0 168 for dim, dim_size in enumerate(dim_sizes): 169 if dim_size.shape.ndims == 1: 170 inner_split = dim + 1 171 elif dim_size.shape.ndims != 0: 172 raise ValueError('Each dim_size must be a scalar or a vector') 173 return RaggedTensorDynamicShape(dim_sizes[:inner_split], 174 dim_sizes[inner_split:]) 175 176 @classmethod 177 def from_tensor(cls, rt_input, dim_size_dtype=None): 178 """Constructs a ragged shape for a potentially ragged tensor.""" 179 with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]): 180 rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input) 181 if not ragged_tensor.is_ragged(rt_input): 182 return cls([], array_ops.shape(rt_input), dim_size_dtype=dim_size_dtype) 183 else: 184 partitioned_dim_sizes = ( 185 (rt_input.nrows(),) + rt_input.nested_row_lengths()) 186 return RaggedTensorDynamicShape( 187 partitioned_dim_sizes, 188 array_ops.shape(rt_input.flat_values)[1:], 189 dim_size_dtype=dim_size_dtype) 190 191 def dimension_size(self, axis): 192 """Returns the size of slices across the specified dimension.""" 193 if not isinstance(axis, int): 194 raise TypeError('axis must be an integer') 195 partitioned_ndims = len(self._partitioned_dim_sizes) 196 if axis < partitioned_ndims: 197 return self._partitioned_dim_sizes[axis] 198 else: 199 return self._inner_dim_sizes[axis - partitioned_ndims] 200 201 def is_ragged(self, axis): 202 """Returns true if the indicated dimension is ragged.""" 203 if not isinstance(axis, int): 204 raise TypeError('axis must be an integer') 205 rank = self.rank 206 if axis < 0: 207 raise ValueError('Negative axis values are not supported') 208 elif rank is not None and axis >= rank: 209 raise ValueError('Expected axis=%s < rank=%s' % (axis, rank)) 210 else: 211 return (axis > 0 and axis < len(self._partitioned_dim_sizes) and 212 self._partitioned_dim_sizes[axis].shape.ndims == 1) 213 214 @property 215 def rank(self): 216 """The number of dimensions in this shape, or None if unknown.""" 217 inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0]) 218 if inner_ndims is None: 219 return None 220 else: 221 return len(self._partitioned_dim_sizes) + inner_ndims 222 223 @property 224 def partitioned_dim_sizes(self): 225 """The partitioned dimension sizes for this shape. 226 227 Returns: 228 A `list` of 0-D or 1-D integer `Tensor`. 229 """ 230 return self._partitioned_dim_sizes 231 232 @property 233 def inner_dim_sizes(self): 234 """The inner dimension sizes for this shape. 235 236 Returns: 237 A 1-D integer `Tensor`. 238 """ 239 return self._inner_dim_sizes 240 241 @property 242 def num_partitioned_dimensions(self): 243 """The number of partitioned dimensions in this shape.""" 244 return len(self._partitioned_dim_sizes) 245 246 @property 247 def num_inner_dimensions(self): 248 """The number of inner dimensions, or `None` if not statically known.""" 249 return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0]) 250 251 @property 252 def dim_size_dtype(self): 253 """DType used by this shape for dimension sizes.""" 254 return self._inner_dim_sizes.dtype 255 256 def broadcast_to_rank(self, rank): 257 """Adds leading size-1 dimensions to broadcast `self` to the given rank. 258 259 E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)` 260 is `[1, 1, 3, (D2), 4]`. 261 262 Args: 263 rank: The rank for the returned shape. 264 265 Returns: 266 A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions 267 have the same size as `self` and whose outer dimensions have size `1`. 268 269 Raises: 270 ValueError: If `self.rank` is unknown or greater than `rank`. 271 """ 272 if self.rank is None: 273 raise ValueError('Unable to broadcast: self.rank is unknown') 274 dims_to_add = rank - self.rank 275 if dims_to_add < 0: 276 raise ValueError('Unable to broadcast: rank=%d must be greater than ' 277 'self.rank=%d.' % (rank, self.rank)) 278 elif dims_to_add == 0: 279 return self 280 elif self._partitioned_dim_sizes: 281 partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes 282 return RaggedTensorDynamicShape(partitioned_dims, self.inner_dim_sizes, 283 self.dim_size_dtype) 284 else: 285 inner_dims = array_ops.concat( 286 [array_ops.ones([dims_to_add], self.dim_size_dtype), 287 self.inner_dim_sizes], 288 axis=0) 289 return RaggedTensorDynamicShape([], inner_dims, self.dim_size_dtype) 290 291 def broadcast_dimension(self, axis, lengths): 292 """Returns a shape that is broadcast-compatible with self & lengths. 293 294 * If dimension[axis] is uniform and lengths is a scalar, the check 295 that either lengths==1 or axis==1 or lengths==axis, and tile 296 dimension[axis] with tf.where(lengths==axis, 1, axis) repeats. 297 298 * If dimension[axis] is uniform and lengths is a vector, then check 299 that dimension[axis]==1, and raggedly tile dimension[axis] with 300 lengths repeats. (we can skip tiling if we statically know that 301 slice_lengths == 1??) 302 303 * If dimension[axis] is ragged and lengths is a scalar, then check 304 that lengths==1. 305 306 * If dimension[axis] is ragged and lengths is a vector, then check 307 that self.dimension_size(axis) == lengths. 308 309 Args: 310 axis: `int`. The dimension to broadcast. 311 lengths: 0-D or 1-D integer `Tensor`. 312 313 Returns: 314 A `RaggedTensorDynamicShape`. 315 """ 316 lengths = ragged_util.convert_to_int_tensor( 317 lengths, name='lengths', dtype=self.dim_size_dtype) 318 # Check whether lengths is a scalar (for uniform dimensions) or 319 # vector (for ragged dimensions). 320 if lengths.shape.ndims is None: 321 raise ValueError('lengths must have a known rank.') 322 elif lengths.shape.ndims > 1: 323 raise ValueError('lengths must be a scalar or vector') 324 else: 325 lengths_is_scalar = (lengths.shape.ndims == 0) 326 327 # Verify that the shapes are compatible. 328 if self.is_ragged(axis): 329 if lengths_is_scalar: 330 condition = math_ops.equal(lengths, 1) 331 else: 332 condition = math_ops.reduce_all( 333 math_ops.equal(lengths, self.dimension_size(axis))) 334 else: 335 axis_dim_size = self.dimension_size(axis) 336 if lengths_is_scalar: 337 condition = ( 338 math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1) 339 | math_ops.equal(axis_dim_size, lengths)) 340 else: 341 condition = math_ops.equal(axis_dim_size, 1) 342 broadcast_err = [ 343 'Unable to broadcast: dimension size mismatch in dimension', axis, 344 'lengths=', lengths, 'dim_size=', 345 self.dimension_size(axis) 346 ] 347 broadcast_check = control_flow_ops.Assert( 348 condition, data=broadcast_err, summarize=10) 349 350 with ops.control_dependencies([broadcast_check]): 351 # Partitioned dimensions: 352 if axis < self.num_partitioned_dimensions: 353 if self.is_ragged(axis): 354 # Use an identity op to make sure the check actually gets run. 355 return RaggedTensorDynamicShape( 356 self._partitioned_dim_sizes, 357 array_ops.identity(self.inner_dim_sizes), self.dim_size_dtype) 358 else: 359 return self._broadcast_uniform_partitioned_dimension(axis, lengths) 360 361 # Inner dimensions: 362 else: 363 if lengths_is_scalar: 364 return self._broadcast_inner_dimension_to_uniform(axis, lengths) 365 else: 366 if axis == 0: 367 raise ValueError('Unable to broadcast: ' 368 'outermost dimension must be uniform.') 369 return self._broadcast_inner_dimension_to_ragged(axis, lengths) 370 371 def num_slices_in_dimension(self, axis): 372 """Returns the total number of slices across the indicated dimension.""" 373 if axis < 0: 374 return constant_op.constant(1, dtype=self.dim_size_dtype) 375 elif self.is_ragged(axis): 376 return math_ops.reduce_sum(self._partitioned_dim_sizes[axis]) 377 else: 378 return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1) 379 380 def _broadcast_uniform_partitioned_dimension(self, axis, lengths): 381 """Broadcasts the partitioned dimension `axis` to match `lengths`.""" 382 axis_dim_size = self.dimension_size(axis) 383 partitioned_sizes = list(self._partitioned_dim_sizes[:axis]) 384 385 if lengths.shape.ndims == 0: 386 lengths = array_ops.where( 387 math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size) 388 repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1) 389 splits = array_ops.stack([0, self.num_slices_in_dimension(axis)]) 390 else: 391 splits = math_ops.range( 392 array_ops.size(lengths, out_type=self.dim_size_dtype) + 1) 393 repeats = lengths 394 395 partitioned_sizes.append(lengths) 396 397 for dim_size in self._partitioned_dim_sizes[axis + 1:]: 398 if dim_size.shape.ndims == 0: 399 partitioned_sizes.append(dim_size) 400 splits *= dim_size 401 else: 402 partitioned_sizes.append( 403 ragged_util.repeat_ranges(dim_size, splits, repeats)) 404 splits = array_ops.gather( 405 ragged_util.lengths_to_splits(dim_size), splits) 406 inner_sizes = self._inner_dim_sizes 407 return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes, 408 self.dim_size_dtype) 409 410 def _broadcast_inner_dimension_to_uniform(self, axis, length): 411 """Broadcasts the inner dimension `axis` to match `lengths`.""" 412 dim_size = self.dimension_size(axis) 413 axis_in_inner_dims = axis - self.num_partitioned_dimensions 414 partitioned_sizes = self._partitioned_dim_sizes 415 inner_sizes = array_ops.concat([ 416 self._inner_dim_sizes[:axis_in_inner_dims], 417 [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)], 418 self._inner_dim_sizes[axis_in_inner_dims + 1:] 419 ], 420 axis=0) 421 return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes, 422 self.dim_size_dtype) 423 424 def _broadcast_inner_dimension_to_ragged(self, axis, lengths): 425 axis_in_inner_dims = axis - self.num_partitioned_dimensions 426 partitioned_sizes = ( 427 self._partitioned_dim_sizes + tuple([ 428 self._inner_dim_sizes[i] for i in range(axis_in_inner_dims) 429 ]) + (lengths,)) 430 inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:] 431 return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes) 432 433 def with_dim_size_dtype(self, dtype): 434 if dtype not in (dtypes.int32, dtypes.int64): 435 raise ValueError('dtype must be int32 or int64') 436 if self.dim_size_dtype == dtype: 437 return self 438 return RaggedTensorDynamicShape( 439 [math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes], 440 math_ops.cast(self._inner_dim_sizes, dtype)) 441 442 443def broadcast_dynamic_shape(shape_x, shape_y): 444 """Returns the shape formed by broadcasting two shapes to be compatible. 445 446 Args: 447 shape_x: A `RaggedTensorDynamicShape` 448 shape_y: A `RaggedTensorDynamicShape` 449 450 Returns: 451 A `RaggedTensorDynamicShape`. 452 Raises: 453 ValueError: If `shape_x` and `shape_y` are not broadcast-compatible. 454 """ 455 if not isinstance(shape_x, RaggedTensorDynamicShape): 456 raise TypeError('shape_x must be a RaggedTensorDynamicShape') 457 if not isinstance(shape_y, RaggedTensorDynamicShape): 458 raise TypeError('shape_y must be a RaggedTensorDynamicShape') 459 460 # Broadcast both shapes to have the same rank. 461 if shape_x.rank is None or shape_y.rank is None: 462 raise ValueError('Unable to broadcast: unknown rank') 463 broadcast_rank = max(shape_x.rank, shape_y.rank) 464 shape_x = shape_x.broadcast_to_rank(broadcast_rank) 465 shape_y = shape_y.broadcast_to_rank(broadcast_rank) 466 467 # Broadcast dimensions one at a time, starting from the outermost dimension. 468 for axis in range(broadcast_rank): 469 shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis)) 470 shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis)) 471 472 return shape_x 473 474 475def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True): 476 """Broadcasts a potentially ragged tensor to a ragged shape. 477 478 Tiles `rt_input` as necessary to match the given shape. 479 480 Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`. 481 482 Args: 483 rt_input: The potentially ragged tensor to broadcast. 484 shape: A `RaggedTensorDynamicShape` 485 broadcast_inner_dimensions: If false, then inner dimensions will not be 486 tiled. 487 488 Returns: 489 A potentially ragged tensor whose values are taken from 490 `rt_input`, and whose shape matches `shape`. 491 """ 492 if not isinstance(shape, RaggedTensorDynamicShape): 493 raise TypeError('shape must be a RaggedTensorDynamicShape') 494 rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input) 495 496 # Broadcasting to a uniform shape. 497 if shape.num_partitioned_dimensions == 0: 498 return _broadcast_to_uniform_shape(rt_input, shape, 499 broadcast_inner_dimensions) 500 else: 501 return _broadcast_to_ragged_shape(rt_input, shape, 502 broadcast_inner_dimensions) 503 504 505def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions): 506 """Broadcasts rt_input to the uniform shape `shape`.""" 507 if isinstance(rt_input, ragged_tensor.RaggedTensor): 508 raise ValueError('Incompatible with shape: ragged rank mismatch') 509 if broadcast_inner_dimensions: 510 return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes) 511 else: 512 return rt_input 513 514 515def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions): 516 """Broadcasts rt_input to the ragged shape `dst_shape`.""" 517 # Check that rt_input and dst_shape have the same row_splits dtype. 518 if (isinstance(rt_input, ragged_tensor.RaggedTensor) and 519 rt_input.row_splits.dtype != dst_shape.dim_size_dtype): 520 if not ragged_config.auto_cast_partition_dtype(): 521 raise ValueError('rt_input and dst_shape have different row_split ' 522 'dtypes; use RaggedTensor.with_row_splits_dtype() or ' 523 'RaggedTensorDynamicShape.with_dim_size_dtype() to ' 524 'convert to a compatible dtype.') 525 rt_input = rt_input.with_row_splits_dtype(dtypes.int64) 526 dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64) 527 528 # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's 529 if rt_input.shape.ndims is None or dst_shape.rank is None: 530 raise ValueError('Unable to broadcast: unknown rank') 531 if rt_input.shape.ndims > dst_shape.rank: 532 raise ValueError('Incompatible with shape: rank mismatch') 533 if (isinstance(rt_input, ragged_tensor.RaggedTensor) and 534 rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions): 535 raise ValueError('Incompatible with shape: ragged rank mismatch') 536 537 src_shape = RaggedTensorDynamicShape.from_tensor(rt_input) 538 src_shape = src_shape.broadcast_to_rank(dst_shape.rank) 539 540 # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape. 541 if dst_shape.rank > rt_input.shape.ndims: 542 if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1: 543 rt_input = array_ops.reshape( 544 rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)) 545 for _ in range(dst_shape.rank - rt_input.shape.ndims): 546 if ragged_tensor.is_ragged(rt_input): 547 nrows = rt_input.nrows() 548 else: 549 nrows = array_ops.shape(rt_input, 550 out_type=dst_shape.dim_size_dtype)[0] 551 rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows], 552 validate=False) 553 554 # Add ragged dimensions to match dst_shape. 555 if ragged_tensor.is_ragged(rt_input): 556 inner_rank_diff = ( 557 rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions) 558 if inner_rank_diff > 0: 559 rt_input = rt_input.with_flat_values( 560 ragged_tensor.RaggedTensor.from_tensor( 561 rt_input.flat_values, ragged_rank=inner_rank_diff, 562 row_splits_dtype=dst_shape.dim_size_dtype)) 563 else: 564 rt_input = ragged_tensor.RaggedTensor.from_tensor( 565 rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1, 566 row_splits_dtype=dst_shape.dim_size_dtype) 567 568 # Do broadcasting for any dimensions that will remain uniform. We can do 569 # these all at once, since they're independent of one another. 570 multiples = [1] * dst_shape.rank 571 for axis in range(dst_shape.num_partitioned_dimensions): 572 if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis): 573 src_size = src_shape.dimension_size(axis) 574 dst_size = dst_shape.dimension_size(axis) 575 if ((tensor_util.constant_value(src_size) in (1, None)) and 576 (tensor_util.constant_value(dst_size) != 1)): 577 multiples[axis] = array_ops.where( 578 math_ops.equal(src_size, 1), dst_size, 1) 579 if not all(isinstance(v, int) and v == 1 for v in multiples): 580 multiples = array_ops.stack(multiples, axis=0) 581 rt_input = ragged_array_ops.tile(rt_input, multiples) 582 583 if broadcast_inner_dimensions: 584 new_shape = array_ops.broadcast_dynamic_shape( 585 array_ops.shape( 586 rt_input.flat_values, out_type=dst_shape.dim_size_dtype), 587 array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0)) 588 rt_input = rt_input.with_flat_values( 589 array_ops.broadcast_to(rt_input.flat_values, new_shape)) 590 591 # Do broadcasting for dimensions that become ragged. We must do these from 592 # outermost to innermost. 593 for axis in range(dst_shape.num_partitioned_dimensions): 594 if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis): 595 dst_size = dst_shape.dimension_size(axis) 596 rt_input = _ragged_tile_axis(rt_input, axis, dst_size, 597 dst_shape.dim_size_dtype) 598 599 return rt_input 600 601 602def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype): 603 """Tile a dimension of a RaggedTensor to match a ragged shape.""" 604 assert axis > 0 # Outermost dimension may not be ragged. 605 606 if not ragged_tensor.is_ragged(rt_input): 607 rt_input = ragged_tensor.RaggedTensor.from_tensor( 608 rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype) 609 610 if axis > 1: 611 return rt_input.with_values( 612 _ragged_tile_axis(rt_input.values, axis - 1, repeats, 613 row_splits_dtype)) 614 else: 615 src_row_splits = rt_input.nested_row_splits 616 src_row_lengths = rt_input.nested_row_lengths() 617 splits = src_row_splits[0] 618 619 dst_row_lengths = [repeats] 620 for i in range(1, len(src_row_lengths)): 621 dst_row_lengths.append( 622 ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats)) 623 splits = array_ops.gather(src_row_splits[i], splits) 624 dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits, 625 repeats) 626 return ragged_tensor.RaggedTensor.from_nested_row_lengths( 627 dst_values, dst_row_lengths, validate=False) 628