# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_config
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util


class RaggedTensorDynamicShape:
  """A collection of tensors encoding the shape of a potentially ragged tensor.

  Each `RaggedTensorDynamicShape` consists of an ordered list of dimension
  sizes.  There are two dimension types:

    * "Uniform dimensions" are dimensions where all slices have the same
      length.  `RaggedTensorDynamicShape` records the size of each uniform
      dimension using a single scalar integer.

    * "Ragged dimensions" are dimensions whose slices may have different
      lengths.  `RaggedTensorDynamicShape` records the size of each ragged
      dimension using an integer vector containing the slice lengths for all
      the slices across that dimension.

  Furthermore, there are two ways a dimension might be encoded:

48    * "Partitioned dimensions" are dimensions that are encoded using a
49      `RaggedTensor`'s `nested_row_splits`.  The outermostmost partitioned
50      dimension must be uniform, and the innermost partitioned dimension must
51      be ragged.
52
53    * "Inner dimensions" are dimensions that are encoded using a
54      `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.
55
56  The sizes of partitioned dimensions are recorded using `partitioned_dim_sizes`
57  and `inner_dim_sizes`:

    * `partitioned_dim_sizes` is a list of tensors (one for each partitioned
      dimension).

      * For uniform dimensions, the tensor is an integer scalar specifying the
        size of all slices across that dimension.
      * For ragged dimensions, the tensor is an integer vector specifying the
        size of each slice across that dimension.

    * `inner_dim_sizes` is a single integer vector, where each element
      specifies the size of a single inner dimension.

  Examples:

  Tensor                         | Ragged | Partitioned Dim Sizes  | Inner Dim
                                 : Rank   :                        : Sizes
  ------------------------------ | ------ | ---------------------- | ----------
  `[[1, 2, 3], [4, 5, 6]]`       |      0 |                        | `2, 3`
  `[[1, 2], [], [3, 4, 5]]`      |      1 | `3, (2, 0, 3)`         |
  `[[[1, 2], [3, 4]], [[5, 6]]]` |      1 | `2, (2, 1)`            | 2
  `[[[1, 2], [3]], [[4, 5]]]`    |      2 | `2, (2, 1), (2, 1, 2)` |
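
  The following is a minimal, illustrative sketch (it assumes eager execution
  and that `tf` has been imported) showing how the second row of the table
  above splits into partitioned and inner dimensions:

  >>> rt = tf.ragged.constant([[1, 2], [], [3, 4, 5]])
  >>> shape = RaggedTensorDynamicShape.from_tensor(rt)
  >>> shape.rank
  2
  >>> shape.num_partitioned_dimensions  # nrows + one ragged row-lengths vector
  2
  >>> shape.num_inner_dimensions
  0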
79  """
80
81  def __init__(self, partitioned_dim_sizes, inner_dim_sizes,
82               dim_size_dtype=None):
83    """Creates a RaggedTensorDynamicShape.
84
85    Args:
86      partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for
87        each partitioned dimension.  If dimension `d` is uniform, then
88        `partitioned_dim_sizes[d]` must be an integer scalar, specifying the
89        size of all slices across dimension `d`.  If dimension `d` is ragged,
90        then `partitioned_dim_sizes[d]` must be an integer vector, specifying
91        the size of each slice across dimension `d`.
92      inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the
93        number of inner dimensions.  `inner_dim_sizes[n]` is the size of all
94        slices across the `n`th inner dimension (which is the
        `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor).
      dim_size_dtype: dtype for dimension sizes.  If not specified, then it
        is chosen based on the dtypes of `partitioned_dim_sizes` and
        `inner_dim_sizes`.
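
    A minimal, illustrative example (directly constructing the shape
    `[2, (2, 1), 2]` from the class docstring):

    >>> shape = RaggedTensorDynamicShape(
    ...     partitioned_dim_sizes=[2, (2, 1)], inner_dim_sizes=[2])
    >>> shape.rank
    3
    >>> shape.num_inner_dimensions
    1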
99    """
100    assert isinstance(partitioned_dim_sizes, (list, tuple))
101
102    with ops.name_scope(None, 'RaggedTensorDynamicShape',
103                        (partitioned_dim_sizes, inner_dim_sizes)):
104      partitioned_dim_sizes = tuple(
105          ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i)
106          for (i, size) in enumerate(partitioned_dim_sizes))
107      inner_dim_sizes = ops.convert_to_tensor(
108          inner_dim_sizes, name='inner_dim_sizes')
109
110      # Validate shapes.
111      if partitioned_dim_sizes:
112        for axis, dimension_size in enumerate(partitioned_dim_sizes):
113          if dimension_size.shape.ndims is None:
114            raise ValueError(
115                'rank of partitioned_dim_sizes[%d] is unknown' % axis)
116          dimension_size.shape.with_rank_at_most(1)
117        if partitioned_dim_sizes[0].shape.ndims == 1:
118          raise ValueError('outermost partitioned dimension must be uniform')
119        if partitioned_dim_sizes[-1].shape.ndims == 0:
120          raise ValueError('innermost partitioned dimension must be ragged')
121      inner_dim_sizes.shape.assert_has_rank(1)
122
123      # Convert dimension size tensors to a single dtype.
124      if dim_size_dtype is None:
125        dim_size_dtypes = set(
126            p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1)
127        if not dim_size_dtypes:
128          dim_size_dtype = dtypes.int64
129        elif len(dim_size_dtypes) == 1:
130          dim_size_dtype = dim_size_dtypes.pop()
131        else:
132          if not ragged_config.auto_cast_partition_dtype():
133            raise ValueError('partitioned_dim_sizes must have matching dtypes')
134          dim_size_dtype = dtypes.int64
135      partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype)
136                                    for p in partitioned_dim_sizes)
137      inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype)
138
139      self._partitioned_dim_sizes = partitioned_dim_sizes
140      self._inner_dim_sizes = inner_dim_sizes
141
142  def __repr__(self):
143    return ('RaggedTensorDynamicShape'
144            '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' %
145            (self._partitioned_dim_sizes, self._inner_dim_sizes))
146
147  @staticmethod
148  def from_dim_sizes(dim_sizes):
149    """Constructs a ragged shape from a list of dimension sizes.
150
151    This list contains a single tensor for each dimension, where the tensor
152    is a scalar if the dimension is uniform, or a vector if the dimension is
153    ragged.
154
155    Args:
156      dim_sizes: List of int32 or int64 scalars or vectors.
157
158    Returns:
159      A RaggedTensorDynamicShape.
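
    A minimal, illustrative example (matching the second row of the table in
    the class docstring):

    >>> shape = RaggedTensorDynamicShape.from_dim_sizes([3, (2, 0, 3)])
    >>> shape.rank
    2
    >>> shape.is_ragged(1)
    True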
160    """
161    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes',
162                        [dim_sizes]):
163      dim_sizes = tuple(
164          ops.convert_to_tensor(size, preferred_dtype=dtypes.int64,
165                                name='dim_sizes') for size in dim_sizes)
166      # Split the dimensions into partitioned & inner dimensions.
167      inner_split = 0
168      for dim, dim_size in enumerate(dim_sizes):
169        if dim_size.shape.ndims == 1:
170          inner_split = dim + 1
171        elif dim_size.shape.ndims != 0:
172          raise ValueError('Each dim_size must be a scalar or a vector')
173      return RaggedTensorDynamicShape(dim_sizes[:inner_split],
174                                      dim_sizes[inner_split:])
175
176  @classmethod
177  def from_tensor(cls, rt_input, dim_size_dtype=None):
178    """Constructs a ragged shape for a potentially ragged tensor."""
179    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
180      rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
181      if not ragged_tensor.is_ragged(rt_input):
182        return cls([], array_ops.shape(rt_input), dim_size_dtype=dim_size_dtype)
183      else:
184        partitioned_dim_sizes = (
185            (rt_input.nrows(),) + rt_input.nested_row_lengths())
186        return RaggedTensorDynamicShape(
187            partitioned_dim_sizes,
188            array_ops.shape(rt_input.flat_values)[1:],
189            dim_size_dtype=dim_size_dtype)
190
191  def dimension_size(self, axis):
192    """Returns the size of slices across the specified dimension."""
193    if not isinstance(axis, int):
194      raise TypeError('axis must be an integer')
195    partitioned_ndims = len(self._partitioned_dim_sizes)
196    if axis < partitioned_ndims:
197      return self._partitioned_dim_sizes[axis]
198    else:
199      return self._inner_dim_sizes[axis - partitioned_ndims]
200
201  def is_ragged(self, axis):
202    """Returns true if the indicated dimension is ragged."""
203    if not isinstance(axis, int):
204      raise TypeError('axis must be an integer')
205    rank = self.rank
206    if axis < 0:
207      raise ValueError('Negative axis values are not supported')
208    elif rank is not None and axis >= rank:
209      raise ValueError('Expected axis=%s < rank=%s' % (axis, rank))
210    else:
211      return (axis > 0 and axis < len(self._partitioned_dim_sizes) and
212              self._partitioned_dim_sizes[axis].shape.ndims == 1)
213
214  @property
215  def rank(self):
216    """The number of dimensions in this shape, or None if unknown."""
217    inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
218    if inner_ndims is None:
219      return None
220    else:
221      return len(self._partitioned_dim_sizes) + inner_ndims
222
223  @property
224  def partitioned_dim_sizes(self):
225    """The partitioned dimension sizes for this shape.
226
227    Returns:
228      A `list` of 0-D or 1-D integer `Tensor`.
229    """
230    return self._partitioned_dim_sizes
231
232  @property
233  def inner_dim_sizes(self):
234    """The inner dimension sizes for this shape.
235
236    Returns:
237      A 1-D integer `Tensor`.
238    """
239    return self._inner_dim_sizes
240
241  @property
242  def num_partitioned_dimensions(self):
243    """The number of partitioned dimensions in this shape."""
244    return len(self._partitioned_dim_sizes)
245
246  @property
247  def num_inner_dimensions(self):
248    """The number of inner dimensions, or `None` if not statically known."""
249    return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0])
250
251  @property
252  def dim_size_dtype(self):
253    """DType used by this shape for dimension sizes."""
254    return self._inner_dim_sizes.dtype
255
256  def broadcast_to_rank(self, rank):
257    """Adds leading size-1 dimensions to broadcast `self` to the given rank.
258
259    E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)`
260    is `[1, 1, 3, (D2), 4]`.
261
262    Args:
263      rank: The rank for the returned shape.
264
265    Returns:
266      A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions
267      have the same size as `self` and whose outer dimensions have size `1`.
268
269    Raises:
270      ValueError: If `self.rank` is unknown or greater than `rank`.
271    """
272    if self.rank is None:
273      raise ValueError('Unable to broadcast: self.rank is unknown')
274    dims_to_add = rank - self.rank
275    if dims_to_add < 0:
      raise ValueError('Unable to broadcast: rank=%d must be greater than '
                       'or equal to self.rank=%d.' % (rank, self.rank))
    elif dims_to_add == 0:
      return self
    elif self._partitioned_dim_sizes:
      partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes
      return RaggedTensorDynamicShape(partitioned_dims, self.inner_dim_sizes,
                                      self.dim_size_dtype)
    else:
      inner_dims = array_ops.concat(
          [array_ops.ones([dims_to_add], self.dim_size_dtype),
           self.inner_dim_sizes],
          axis=0)
      return RaggedTensorDynamicShape([], inner_dims, self.dim_size_dtype)

  def broadcast_dimension(self, axis, lengths):
    """Returns a shape that is broadcast-compatible with self & lengths.

    * If dimension[axis] is uniform and lengths is a scalar, then check
      that lengths==1 or dimension[axis]==1 or dimension[axis]==lengths,
      and tile dimension[axis] with tf.where(dimension[axis]==1, lengths, 1)
      repeats.

    * If dimension[axis] is uniform and lengths is a vector, then check
      that dimension[axis]==1, and raggedly tile dimension[axis] with
      lengths repeats.  (we can skip tiling if we statically know that
      slice_lengths == 1??)

    * If dimension[axis] is ragged and lengths is a scalar, then check
      that lengths==1.

    * If dimension[axis] is ragged and lengths is a vector, then check
      that self.dimension_size(axis) == lengths.

    Args:
      axis: `int`.  The dimension to broadcast.
      lengths: 0-D or 1-D integer `Tensor`.

    Returns:
      A `RaggedTensorDynamicShape`.
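
    A minimal, illustrative example (broadcasting a uniform size-1 dimension
    against ragged slice lengths):

    >>> shape = RaggedTensorDynamicShape.from_dim_sizes([2, 1])
    >>> shape.broadcast_dimension(1, [3, 2]).is_ragged(1)
    True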
315    """
316    lengths = ragged_util.convert_to_int_tensor(
317        lengths, name='lengths', dtype=self.dim_size_dtype)
318    # Check whether lengths is a scalar (for uniform dimensions) or
319    # vector (for ragged dimensions).
320    if lengths.shape.ndims is None:
321      raise ValueError('lengths must have a known rank.')
322    elif lengths.shape.ndims > 1:
323      raise ValueError('lengths must be a scalar or vector')
324    else:
325      lengths_is_scalar = (lengths.shape.ndims == 0)
326
327    # Verify that the shapes are compatible.
328    if self.is_ragged(axis):
329      if lengths_is_scalar:
330        condition = math_ops.equal(lengths, 1)
331      else:
332        condition = math_ops.reduce_all(
333            math_ops.equal(lengths, self.dimension_size(axis)))
334    else:
335      axis_dim_size = self.dimension_size(axis)
336      if lengths_is_scalar:
337        condition = (
338            math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1)
339            | math_ops.equal(axis_dim_size, lengths))
340      else:
341        condition = math_ops.equal(axis_dim_size, 1)
342    broadcast_err = [
343        'Unable to broadcast: dimension size mismatch in dimension', axis,
344        'lengths=', lengths, 'dim_size=',
345        self.dimension_size(axis)
346    ]
347    broadcast_check = control_flow_ops.Assert(
348        condition, data=broadcast_err, summarize=10)
349
350    with ops.control_dependencies([broadcast_check]):
351      # Partitioned dimensions:
352      if axis < self.num_partitioned_dimensions:
353        if self.is_ragged(axis):
354          # Use an identity op to make sure the check actually gets run.
355          return RaggedTensorDynamicShape(
356              self._partitioned_dim_sizes,
357              array_ops.identity(self.inner_dim_sizes), self.dim_size_dtype)
358        else:
359          return self._broadcast_uniform_partitioned_dimension(axis, lengths)
360
361      # Inner dimensions:
362      else:
363        if lengths_is_scalar:
364          return self._broadcast_inner_dimension_to_uniform(axis, lengths)
365        else:
366          if axis == 0:
367            raise ValueError('Unable to broadcast: '
368                             'outermost dimension must be uniform.')
369          return self._broadcast_inner_dimension_to_ragged(axis, lengths)
370
371  def num_slices_in_dimension(self, axis):
372    """Returns the total number of slices across the indicated dimension."""
    if axis < 0:
      return constant_op.constant(1, dtype=self.dim_size_dtype)
    elif self.is_ragged(axis):
      return math_ops.reduce_sum(self._partitioned_dim_sizes[axis])
    else:
      return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1)

  def _broadcast_uniform_partitioned_dimension(self, axis, lengths):
    """Broadcasts the partitioned dimension `axis` to match `lengths`."""
    axis_dim_size = self.dimension_size(axis)
    partitioned_sizes = list(self._partitioned_dim_sizes[:axis])

    if lengths.shape.ndims == 0:
      lengths = array_ops.where(
          math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size)
      repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1)
      splits = array_ops.stack([0, self.num_slices_in_dimension(axis)])
    else:
      splits = math_ops.range(
          array_ops.size(lengths, out_type=self.dim_size_dtype) + 1)
      repeats = lengths

    partitioned_sizes.append(lengths)

    for dim_size in self._partitioned_dim_sizes[axis + 1:]:
      if dim_size.shape.ndims == 0:
        partitioned_sizes.append(dim_size)
        splits *= dim_size
      else:
        partitioned_sizes.append(
            ragged_util.repeat_ranges(dim_size, splits, repeats))
        splits = array_ops.gather(
            ragged_util.lengths_to_splits(dim_size), splits)
    inner_sizes = self._inner_dim_sizes
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes,
                                    self.dim_size_dtype)

  def _broadcast_inner_dimension_to_uniform(self, axis, length):
411    """Broadcasts the inner dimension `axis` to match `lengths`."""
    dim_size = self.dimension_size(axis)
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = self._partitioned_dim_sizes
    inner_sizes = array_ops.concat([
        self._inner_dim_sizes[:axis_in_inner_dims],
        [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)],
        self._inner_dim_sizes[axis_in_inner_dims + 1:]
    ],
                                   axis=0)
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes,
                                    self.dim_size_dtype)

  def _broadcast_inner_dimension_to_ragged(self, axis, lengths):
    """Broadcasts the inner dimension `axis` to the ragged sizes `lengths`."""
    axis_in_inner_dims = axis - self.num_partitioned_dimensions
    partitioned_sizes = (
        self._partitioned_dim_sizes + tuple([
            self._inner_dim_sizes[i] for i in range(axis_in_inner_dims)
        ]) + (lengths,))
    inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:]
    return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes)

  def with_dim_size_dtype(self, dtype):
    """Returns a copy of this shape with dimension sizes cast to `dtype`."""
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('dtype must be int32 or int64')
    if self.dim_size_dtype == dtype:
      return self
    return RaggedTensorDynamicShape(
        [math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes],
        math_ops.cast(self._inner_dim_sizes, dtype))


def broadcast_dynamic_shape(shape_x, shape_y):
  """Returns the shape formed by broadcasting two shapes to be compatible.

  Args:
    shape_x: A `RaggedTensorDynamicShape`
    shape_y: A `RaggedTensorDynamicShape`

  Returns:
    A `RaggedTensorDynamicShape`.

  Raises:
    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
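
  A minimal, illustrative example (broadcasting the rank-1 shape `[1]` against
  the ragged shape `[3, (2, 0, 3)]`):

  >>> x = RaggedTensorDynamicShape.from_dim_sizes([3, (2, 0, 3)])
  >>> y = RaggedTensorDynamicShape.from_dim_sizes([1])
  >>> result = broadcast_dynamic_shape(x, y)
  >>> result.rank
  2
  >>> result.is_ragged(1)
  True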
454  """
455  if not isinstance(shape_x, RaggedTensorDynamicShape):
456    raise TypeError('shape_x must be a RaggedTensorDynamicShape')
457  if not isinstance(shape_y, RaggedTensorDynamicShape):
458    raise TypeError('shape_y must be a RaggedTensorDynamicShape')
459
460  # Broadcast both shapes to have the same rank.
461  if shape_x.rank is None or shape_y.rank is None:
462    raise ValueError('Unable to broadcast: unknown rank')
463  broadcast_rank = max(shape_x.rank, shape_y.rank)
464  shape_x = shape_x.broadcast_to_rank(broadcast_rank)
465  shape_y = shape_y.broadcast_to_rank(broadcast_rank)
466
467  # Broadcast dimensions one at a time, starting from the outermost dimension.
468  for axis in range(broadcast_rank):
469    shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis))
470    shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis))
471
472  return shape_x
473
474
475def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
476  """Broadcasts a potentially ragged tensor to a ragged shape.
477
478  Tiles `rt_input` as necessary to match the given shape.
479
480  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.
481
482  Args:
483    rt_input: The potentially ragged tensor to broadcast.
484    shape: A `RaggedTensorDynamicShape`
485    broadcast_inner_dimensions: If false, then inner dimensions will not be
486      tiled.
487
488  Returns:
489    A potentially ragged tensor whose values are taken from
490    `rt_input`, and whose shape matches `shape`.
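
  A minimal, illustrative example (run eagerly; broadcasting a `[3, 1]` tensor
  to the ragged shape `[3, (2, 0, 3)]`):

  >>> shape = RaggedTensorDynamicShape.from_dim_sizes([3, (2, 0, 3)])
  >>> broadcast_to([[10], [20], [30]], shape).to_list()
  [[10, 10], [], [30, 30, 30]]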
491  """
492  if not isinstance(shape, RaggedTensorDynamicShape):
493    raise TypeError('shape must be a RaggedTensorDynamicShape')
494  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
495
496  # Broadcasting to a uniform shape.
497  if shape.num_partitioned_dimensions == 0:
498    return _broadcast_to_uniform_shape(rt_input, shape,
499                                       broadcast_inner_dimensions)
500  else:
501    return _broadcast_to_ragged_shape(rt_input, shape,
502                                      broadcast_inner_dimensions)
503
504
505def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions):
506  """Broadcasts rt_input to the uniform shape `shape`."""
507  if isinstance(rt_input, ragged_tensor.RaggedTensor):
508    raise ValueError('Incompatible with shape: ragged rank mismatch')
509  if broadcast_inner_dimensions:
510    return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes)
511  else:
512    return rt_input
513
514
515def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions):
516  """Broadcasts rt_input to the ragged shape `dst_shape`."""
517  # Check that rt_input and dst_shape have the same row_splits dtype.
518  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
519      rt_input.row_splits.dtype != dst_shape.dim_size_dtype):
520    if not ragged_config.auto_cast_partition_dtype():
521      raise ValueError('rt_input and dst_shape have different row_split '
522                       'dtypes; use RaggedTensor.with_row_splits_dtype() or '
523                       'RaggedTensorDynamicShape.with_dim_size_dtype() to '
524                       'convert to a compatible dtype.')
525    rt_input = rt_input.with_row_splits_dtype(dtypes.int64)
526    dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64)
527
528  # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's
529  if rt_input.shape.ndims is None or dst_shape.rank is None:
530    raise ValueError('Unable to broadcast: unknown rank')
531  if rt_input.shape.ndims > dst_shape.rank:
532    raise ValueError('Incompatible with shape: rank mismatch')
533  if (isinstance(rt_input, ragged_tensor.RaggedTensor) and
534      rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions):
535    raise ValueError('Incompatible with shape: ragged rank mismatch')
536
537  src_shape = RaggedTensorDynamicShape.from_tensor(rt_input)
538  src_shape = src_shape.broadcast_to_rank(dst_shape.rank)
539
  # Add dimensions to rt_input so its rank and ragged_rank match dst_shape.
  if dst_shape.rank > rt_input.shape.ndims:
    if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1:
      rt_input = array_ops.reshape(
          rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0))
    for _ in range(dst_shape.rank - rt_input.shape.ndims):
      if ragged_tensor.is_ragged(rt_input):
        nrows = rt_input.nrows()
      else:
        nrows = array_ops.shape(rt_input,
                                out_type=dst_shape.dim_size_dtype)[0]
      rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows],
                                                             validate=False)

  # Add ragged dimensions to match dst_shape.
  if ragged_tensor.is_ragged(rt_input):
    inner_rank_diff = (
        rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions)
    if inner_rank_diff > 0:
      rt_input = rt_input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              rt_input.flat_values, ragged_rank=inner_rank_diff,
              row_splits_dtype=dst_shape.dim_size_dtype))
  else:
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1,
        row_splits_dtype=dst_shape.dim_size_dtype)

  # Do broadcasting for any dimensions that will remain uniform.  We can do
  # these all at once, since they're independent of one another.
  multiples = [1] * dst_shape.rank
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis):
      src_size = src_shape.dimension_size(axis)
      dst_size = dst_shape.dimension_size(axis)
      if ((tensor_util.constant_value(src_size) in (1, None)) and
          (tensor_util.constant_value(dst_size) != 1)):
        multiples[axis] = array_ops.where(
            math_ops.equal(src_size, 1), dst_size, 1)
  if not all(isinstance(v, int) and v == 1 for v in multiples):
    multiples = array_ops.stack(multiples, axis=0)
    rt_input = ragged_array_ops.tile(rt_input, multiples)

  if broadcast_inner_dimensions:
    new_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(
            rt_input.flat_values, out_type=dst_shape.dim_size_dtype),
        array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0))
    rt_input = rt_input.with_flat_values(
        array_ops.broadcast_to(rt_input.flat_values, new_shape))

  # Do broadcasting for dimensions that become ragged.  We must do these from
  # outermost to innermost.
  for axis in range(dst_shape.num_partitioned_dimensions):
    if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis):
      dst_size = dst_shape.dimension_size(axis)
      rt_input = _ragged_tile_axis(rt_input, axis, dst_size,
                                   dst_shape.dim_size_dtype)

  return rt_input


def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
  """Tile a dimension of a RaggedTensor to match a ragged shape."""
  assert axis > 0  # Outermost dimension may not be ragged.

  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)

  if axis > 1:
    return rt_input.with_values(
        _ragged_tile_axis(rt_input.values, axis - 1, repeats,
                          row_splits_dtype))
  else:
    src_row_splits = rt_input.nested_row_splits
    src_row_lengths = rt_input.nested_row_lengths()
    splits = src_row_splits[0]

    dst_row_lengths = [repeats]
    for i in range(1, len(src_row_lengths)):
      dst_row_lengths.append(
          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
      splits = array_ops.gather(src_row_splits[i], splits)
    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
                                           repeats)
    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
        dst_values, dst_row_lengths, validate=False)