# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors.

TODO(martinz): make this suitable for output for tf.shape
TODO(martinz): replace ragged_tensor_shape with this.
"""


import abc
from typing import Any, Iterable, Optional, Sequence, Tuple, Union

import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.ops.ragged.row_partition import RowPartitionSpec
from tensorflow.python.types import core
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


class _DynamicRaggedShapeBatchEncoder(extension_type.ExtensionTypeBatchEncoder):
  """A batch encoder for DynamicRaggedShape below."""

  def batch(self, spec: "DynamicRaggedShape.Spec",
            batch_size) -> "DynamicRaggedShape.Spec":
    if spec.num_row_partitions:
      new_head = _batch_rp_spec_head(spec._row_partitions[0], batch_size)  # pylint:disable=protected-access
      new_tail = [_batch_rp_spec(rp, batch_size) for rp in spec._row_partitions]  # pylint:disable=protected-access
      new_rp = [new_head] + new_tail
      new_static_inner_shape = _batch_static_inner_shape(
          spec._static_inner_shape, batch_size)  # pylint:disable=protected-access

      return DynamicRaggedShape.Spec(
          row_partitions=new_rp,
          static_inner_shape=new_static_inner_shape,
          dtype=spec.dtype)
    elif batch_size is None:
      if spec.inner_rank == 0:
        return DynamicRaggedShape.Spec._from_tensor_shape([None],  # pylint:disable=protected-access
                                                          0,
                                                          dtype=spec.dtype)
      else:
        # Note: spec._dimension(0) might be None.
        new_head = RowPartitionSpec(uniform_row_length=spec._dimension(0),  # pylint:disable=protected-access
                                    dtype=spec.dtype)
        new_static_inner_shape = _batch_static_inner_shape(
            spec._static_inner_shape, batch_size)  # pylint:disable=protected-access
        return DynamicRaggedShape.Spec(
            row_partitions=[new_head],
            static_inner_shape=new_static_inner_shape,
            dtype=spec.dtype)
    else:
      return DynamicRaggedShape.Spec(
          row_partitions=[],
          static_inner_shape=_batch_tensor_shape(spec._static_inner_shape,  # pylint:disable=protected-access
                                                 batch_size),
          dtype=spec.dtype)

  def unbatch(self,
              spec: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape.Spec":
    if spec.num_row_partitions:
      result = []
      head = spec._row_partitions[0]  # pylint:disable=protected-access
      scale = None if head.uniform_row_length is None else head.nrows

      for rp in spec._row_partitions[1:]:  # pylint:disable=protected-access
        if scale is None:
          result.append(
              RowPartitionSpec(
                  nrows=None,
                  nvals=None,
                  uniform_row_length=rp.uniform_row_length,
                  dtype=spec.dtype))
        else:
          nrows = None if rp.nrows is None else rp.nrows // scale
          if rp.uniform_row_length is None:
            scale = None
            result.append(RowPartitionSpec(nrows=nrows,
                                           nvals=None,
                                           uniform_row_length=None,
                                           dtype=spec.dtype))
          else:
            result.append(
                RowPartitionSpec(
                    nrows=nrows,
                    nvals=rp.nvals // scale,
                    uniform_row_length=rp.uniform_row_length,
                    dtype=spec.dtype))
      return DynamicRaggedShape.Spec(
          row_partitions=result,
          static_inner_shape=_unbatch_static_inner_shape(
              spec._static_inner_shape, scale),  # pylint:disable=protected-access
          dtype=spec.dtype)
    else:  # spec.num_row_partitions == 0
      return DynamicRaggedShape.Spec(
          row_partitions=[],
          static_inner_shape=spec._static_inner_shape[1:],  # pylint:disable=protected-access
          dtype=spec.dtype)

  def decode(self, spec: "DynamicRaggedShape.Spec", encoding
             ) -> "DynamicRaggedShape":
    return DynamicRaggedShape.from_tensor(encoding, dtype=spec.dtype)

  def encode(self, spec: "DynamicRaggedShape.Spec", value, minimum_rank=0
             ) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]:
    return ones(value, dtype=dtypes.bool)

  def encoding_specs(
      self,
      spec: "DynamicRaggedShape.Spec"
      ) -> Union[ragged_tensor.RaggedTensorSpec, tensor_spec.TensorSpec]:
    if spec.rank != 0:
      ragged_rank = spec.num_row_partitions
    else:
      # Special case: need to unbatch twice to get a ragged tensor.
      ragged_rank = -1
    return ragged_tensor.RaggedTensorSpec(
        shape=spec._to_tensor_shape(),  # pylint:disable=protected-access
        dtype=dtypes.bool,
        ragged_rank=ragged_rank,
        row_splits_dtype=spec.dtype)


# TODO(martinz): allow inner_shape to be a fully defined TensorShape.
# A "fully defined TensorShape" means one where the rank and all dimensions are
# known.
# Allowing inner_shape might mean allowing inner_shape to be initialized by
# a fully defined TensorShape, or it might mean that you can actually store
# TensorShape in the inner_shape field. This could conceivably construct
# a DynamicRaggedShape that was dtype agnostic.
#
# TODO(martinz): unify the impl of the determination of index type across
#     RowPartition and DynamicRaggedShape.
@tf_export("experimental.DynamicRaggedShape")
class DynamicRaggedShape(extension_type.BatchableExtensionType):
  """The shape of a ragged or dense tensor.

  Ragged shapes are encoded using two fields:

  * `inner_shape`: An integer vector giving the shape of a dense tensor.
  * `row_partitions`: A list of `RowPartition` objects, describing how
    that flat shape should be partitioned to add ragged axes.

  If a DynamicRaggedShape is the shape of a RaggedTensor rt, then:
  1. row_partitions = rt._nested_row_partitions
     (and thus len(row_partitions) > 0)
  2. inner_shape is the shape of rt.flat_values

  If a DynamicRaggedShape is the shape of a dense tensor t, then:
  1. row_partitions = []
  2. inner_shape is the shape of t.

  Examples:

  The following table gives a few examples (where `RP(lengths)` is short
  for `RowPartition.from_row_lengths(lengths)`):

  Row Partitions              | Inner Shape  | Example Tensor
  --------------------------- | ------------ | ----------------------------
  []                          | [2, 3]       | `[[1, 2, 3], [4, 5, 6]]`
  [RP([2, 0, 3])]             | [5]          | `[[1, 2], [], [3, 4, 5]]`
  [RP([2, 1])]                | [3, 2]       | `[[[1, 2], [3, 4]], [[5, 6]]]`
  [RP([2, 1]), RP([2, 1, 2])] | [5]          | `[[[1, 2], [3]], [[4, 5]]]`
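
  #### Example:

  An illustrative sketch of constructing a shape directly from its row
  partitions (output assumes eager execution):

  >>> tf.experimental.DynamicRaggedShape(
  ...     row_partitions=[tf.experimental.RowPartition.from_row_lengths(
  ...         [2, 0, 3])],
  ...     inner_shape=[5])
  <DynamicRaggedShape lengths=[3, (2, 0, 3)] num_row_partitions=1>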
189  """
190  _row_partitions: Tuple[RowPartition, ...]
191  _inner_shape: ops.Tensor
192  _static_inner_shape: tensor_shape.TensorShape
193  __batch_encoder__ = _DynamicRaggedShapeBatchEncoder()
194  __name__ = "tf.DynamicRaggedShape"
195
196  def __init__(self,
197               row_partitions: Sequence[RowPartition],
198               inner_shape: core.TensorLike,
199               dtype: Optional[dtypes.DType] = None,
200               validate: bool = False,
201               static_inner_shape: ... = None):
202    """Core constructor for a DynamicRaggedShape.
203
204    Create a DynamicRaggedShape. This can be used to construct a
205    DynamicRaggedShape representing a ragged or dense shape. If row_partitions
206    is an empty list, then this is equivalent to a dense shape.
207
208    If row_partitions is specified, then the num_row_partitions will be equal
209    to len(row_partitions). There are several checks made.
210    Specifically:
211    1. Consecutive row_partitions must have consistent nvals and nrows.
212    2. The last row_partitions must have nvals equal to the first element of
213       inner_shape.
214
215    The inner_shape is converted to a tensor.
216    All row_partitions and the inner_shape are converted to the same dtype
217    (int64 or int32).
218
219    Args:
220      row_partitions: the row_partitions of the shape.
221      inner_shape: if len(row_partitions) > 0, the shape of the flat_values.
222        Otherwise, the shape of the tensor.
223      dtype: tf.int64, tf.int32, or None representing the preferred dtype.
224      validate: if true, dynamic validation is applied to the shape.
225      static_inner_shape: if len(row_partitions) > 0, the static shape of the
226        flat_values. Otherwise, the static shape of the tensor.
227        Should be convertible to a TensorShape.
228    """
229    if not isinstance(row_partitions, Iterable):
230      raise TypeError(
231          "row_partitions should be a list of row partitions. Instead, got " +
232          str(row_partitions))
233    for x in row_partitions:
234      if not isinstance(x, RowPartition):
235        raise TypeError("row_partitions contains " + str(x) +
236                        " which is not a RowPartition")
237    dtype = _find_dtype_iterable(row_partitions, dtype)
238    dtype = _find_dtype(inner_shape, dtype)
239    if (isinstance(inner_shape, np.ndarray) and
240        inner_shape.dtype == np.int32 and dtype is None):
241      dtype = dtypes.int32
242    dtype = _find_dtype(dtypes.int64, dtype)
243
244    row_partitions = tuple([rp.with_dtype(dtype) for rp in row_partitions])
245    self._row_partitions = row_partitions
246    self._inner_shape = ops.convert_to_tensor(
247        inner_shape, dtype_hint=dtype, name="inner_dim_sizes")
248    if self._inner_shape.dtype != dtype:
249      self._inner_shape = math_ops.cast(self._inner_shape, dtype)
250
251    checks = []
252    # Validate shapes.
253    if self._row_partitions:
254      for axis, rp in enumerate(self._row_partitions):
255        if axis > 0:
256          previous_row_partition = self._row_partitions[axis - 1]
257          msg = ("RowPartitions in DynamicRaggedShape do not align "
258                 f"between {axis - 1} and {axis}")
259          static_nrows = rp.static_nrows
260          static_nvals = previous_row_partition.static_nvals
261          if (static_nrows is not None) and (static_nvals is not None):
262            if static_nrows != static_nvals:
263              raise ValueError(msg)
264            else:
265              continue
266          if validate:
267            checks.append(
268                check_ops.assert_equal(
269                    previous_row_partition.nvals(),
270                    rp.nrows(),
271                    message=msg))
272
273    self._inner_shape.shape.assert_has_rank(1)
274
275    self._static_inner_shape = tensor_util.constant_value_as_shape(
276        self._inner_shape)
277    if static_inner_shape is not None:
278      self._static_inner_shape = self._static_inner_shape.merge_with(
279          static_inner_shape)
280
281    if row_partitions:
282      last_row_partition = row_partitions[-1]
283      static_nvals = last_row_partition.static_nvals
284      static_inner_shape_nvals = tensor_shape.dimension_value(
285          self._static_inner_shape[0])
286      if static_nvals is not None and static_inner_shape_nvals is not None:
287        if static_nvals != static_inner_shape_nvals:
288          raise ValueError("Last row partition does not match inner_shape.")
289      elif validate:
290        checks.append(
291            check_ops.assert_equal(
292                last_row_partition.nvals(),
293                self._inner_shape[0],
294                message="Last row partition does not match inner_shape."))
295    if checks:
296      self._inner_shape = control_flow_ops.with_dependencies(
297          checks, self._inner_shape, name="inner_shape_validated")
298      self._row_partitions = [
299          rp._with_dependencies(checks) for rp in self._row_partitions  # pylint: disable=protected-access
300      ]
301
302  @classmethod
303  def from_lengths(cls,
304                   lengths: Sequence[Union[Sequence[int], int]],
305                   num_row_partitions=None,
306                   dtype=dtypes.int64):
307    """Creates a shape with the given lengths and num_row_partitions.
308
309    The lengths can either be a nonnegative int or a list of nonnegative ints.
310
311    If num_row_partitions is None, then the minimal num_row_partitions is used.
312
313    For example, [2, (3, 2)] is the shape of [[0, 0, 0], [0, 0]], and
314    [2, 2] is the shape of [[0, 0], [0, 0]]
315
316    This chooses the minimal num_row_partitions required (including zero).
317
318    The following table gives a few examples (where `RP(lengths)` is short
319    for `RowPartition.from_lengths(lengths)`):
320
321    For example:
322    from_lengths           | row_partitions            | inner_shape
323    ---------------------- | --------------------------| -------------
324    []                     | []                        | []
325    [2, (3, 2)]            | [RP([3, 2])]              | [5]
326    [2, 2]                 | []                        | [2, 2]
327    [2, (3, 2), 7]         | [RP([3, 2])]              | [5, 7]
328    [2, (2, 2), 3]         | [RP([2, 2])]              | [4, 3]
329    [2, 2, 3]              | []                        | [2, 2, 3]
330    [2, (2, 1), (2, 0, 3)] | [RP(2, 1), RP([2, 0, 3])] | [5]
331
332    If we want the row partitions to end with uniform row partitions, then
333    we can set num_row_partitions.
334
335    For example,
336    below URP(3, 12) is RowPartition.from_uniform_row_length(3, 12)
337
338    from_lengths   | num_row_partitions | row_partitions           | inner_shape
339    ---------------| -------------------|--------------------------|------------
340    [2, (3, 2), 2] | 2                  | [RP([3, 2]), URP(2, 10)] | [10]
341    [2, 2]         | 1                  | [URP(2, 4)]              | [4]
342    [2, 2, 3]      | 0                  | []                       | [2, 2, 3]
343    [2, 2, 3]      | 1                  | [URP(2, 4)]              | [4, 3]
344    [2, 2, 3]      | 2                  | [URP(2, 4), URP(3, 12)]  | [12]
345
346
347
348    Representing the shapes from init():
349
350    from_lengths             | Tensor Example
351    ------------------------ | ------------------------------
352    `[2, 3]`                 | `[[1, 2, 3], [4, 5, 6]]`
353    `[3, (2, 0, 3)]`         | `[[1, 2], [], [3, 4, 5]]`
354    `[2, (2, 1), 2]`         | `[[[1, 2], [3, 4]], [[5, 6]]]`
355    `[2, (2, 1), (2, 1, 2)]` | `[[[1, 2], [3]], [[4, 5]]]`
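
    #### Example:

    An illustrative sketch (output assumes eager execution):

    >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (3, 2)])
    <DynamicRaggedShape lengths=[2, (3, 2)] num_row_partitions=1>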

    Args:
      lengths: the lengths of sublists along each axis.
      num_row_partitions: the num_row_partitions of the result, or None to
        use the minimum number of row_partitions.
      dtype: the dtype of the shape (tf.int32 or tf.int64).

    Returns:
      a new DynamicRaggedShape
    """
    if not isinstance(lengths, list):
      raise ValueError("lengths should be a list")
    for x in lengths:
      if not _is_int_or_tuple_of_ints(x):
        raise ValueError(
            "element of lengths should be an int or a tuple of ints: "
            "instead got %r" % (x,))

    if num_row_partitions is None:
      # Calculate the minimal num_row_partitions.
      is_list = [not isinstance(x, int) for x in lengths]
      if any(is_list):
        # The index of the last element of lengths that is not an int.
        num_row_partitions = len(is_list) - is_list[-1::-1].index(True) - 1
      else:
        num_row_partitions = 0

    if not isinstance(num_row_partitions, int):
      raise ValueError("num_row_partitions should be an int or None")

    if not lengths:
      if num_row_partitions > 0:
        raise ValueError("num_row_partitions must be 0 for a scalar shape")
      return DynamicRaggedShape([], [], dtype=dtype)

    if num_row_partitions >= len(lengths):
      raise ValueError(
          "num_row_partitions should be less than `len(lengths)` "
          "if shape is not scalar.")

    if num_row_partitions > 0:
      (row_partitions, nvals) = _to_row_partitions_and_nvals_from_lengths(
          lengths[:num_row_partitions + 1])
      inner_shape = [nvals] + lengths[num_row_partitions + 1:]
      return DynamicRaggedShape(
          row_partitions, inner_shape, dtype=dtype)
    else:
      return DynamicRaggedShape([], lengths, dtype=dtype)

  @classmethod
  def from_row_partitions(cls, row_partitions, dtype=None):
    """Create a shape from row_partitions.

    Args:
      row_partitions: a nonempty list of RowPartition objects.
      dtype: the dtype to use, or None to use the row_partitions dtype.

    Returns:
      a DynamicRaggedShape with inner_rank==1.
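
    #### Example:

    An illustrative sketch (output assumes eager execution):

    >>> rp = tf.experimental.RowPartition.from_row_lengths([2, 0, 3])
    >>> tf.experimental.DynamicRaggedShape.from_row_partitions([rp])
    <DynamicRaggedShape lengths=[3, (2, 0, 3)] num_row_partitions=1>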
416    """
417    if not row_partitions:
418      raise ValueError("row_partitions cannot be empty")
419    inner_shape = [row_partitions[-1].nvals()]
420    return DynamicRaggedShape(
421        row_partitions, inner_shape, dtype=dtype)
422
423  @classmethod
424  def _from_inner_shape(cls, inner_shape, dtype=None):
425    """Create a shape from inner_shape, where num_row_partitions == 0."""
426    return DynamicRaggedShape([], inner_shape, dtype=dtype)
427
428  # pylint: disable=protected-access
429  @classmethod
430  def from_tensor(cls, t, dtype=None):
431    """Constructs a ragged shape for a potentially ragged tensor."""
432    if ragged_tensor.is_ragged(t):
433      return DynamicRaggedShape(
434          t._nested_row_partitions, _flat_values_shape(t), dtype=dtype)
435    else:
436      return DynamicRaggedShape._from_inner_shape(
437          array_ops.shape(t), dtype=dtype)
438
439  @property
440  def row_partitions(self):
441    """The row_partitions of the shape."""
442    return self._row_partitions
443
444  @property
445  def num_row_partitions(self):
446    """The number of row_partitions of the shape."""
447    return len(self._row_partitions)
448
449  @property
450  def dtype(self):
451    """The dtype of the shape -- one of tf.int32 or tf.int64."""
452    return self._inner_shape.dtype
453
454  def _static_inner_shape_as_list(self, truncate_first):
455    """Returns the lengths of the inner shape (if rank known), or [...]."""
456    if self._static_inner_shape.rank is None:
457      return [...]
458    result = self._static_inner_shape.as_list()
459    if truncate_first:
460      return result[1:]
461    return result
462
463  def static_lengths(self, ragged_lengths=True):
464    """Returns a list of statically known axis lengths.
465
466    This represents what values are known. For each row partition, it presents
467    either the uniform row length (if statically known),
468    the list of row lengths, or none if it is not statically known.
469    For the inner shape, if the rank is known, then each dimension is reported
470    if known, and None otherwise. If the rank of the inner shape is not known,
471    then the returned list ends with an ellipsis.
472
473    Args:
474      ragged_lengths: If false, returns None for all ragged dimensions.
475
476    Returns:
477      A Sequence[Union[Sequence[int],int, None]] of lengths, with a possible
478      Ellipsis at the end.
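
    #### Example:

    An illustrative sketch (output assumes eager execution):

    >>> shape = tf.experimental.DynamicRaggedShape.from_lengths(
    ...     [2, (3, 2), 7])
    >>> shape.static_lengths()
    [2, (3, 2), 7]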
479    """
480    if self.num_row_partitions == 0:
481      return self._static_inner_shape_as_list(False)
482    first_dim = self.row_partitions[0].static_nrows
483    if isinstance(first_dim, tensor_shape.Dimension):
484      first_dim = first_dim.value
485    rp_dims = [first_dim]
486    for rp in self.row_partitions:
487      if rp.is_uniform():
488        rp_dims.append(rp.static_uniform_row_length)
489      elif ragged_lengths:
490        const_vals = tensor_util.constant_value(rp.row_lengths())
491        if const_vals is None:
492          rp_dims.append(None)
493        else:
494          rp_dims.append(tuple(const_vals.tolist()))
495      else:
496        rp_dims.append(None)
497
498    return rp_dims + self._static_inner_shape_as_list(True)
499
500  def __repr__(self):
501    lengths = _list_with_ellipsis_to_str(self.static_lengths())
502    return ("<DynamicRaggedShape "
503            "lengths=%s num_row_partitions=%r>" %
504            (lengths, self.num_row_partitions))
505
506  def _to_tensor_shape(self) -> tensor_shape.TensorShape:
507    """Returns a TensorShape representation of the shape."""
508    lengths = self.static_lengths(ragged_lengths=False)
509    if not lengths:
510      return tensor_shape.TensorShape(())
511    if lengths[-1] == Ellipsis:
512      return tensor_shape.TensorShape(None)
513    return tensor_shape.TensorShape(lengths)
514
515  def _slice_shape(self, start, stop):
516    """Returns a shape self[start:stop].
517
518    If start == 0, then this truncates dimensions after stop.
519    If start != 0, then this will return a shape with num_row_partitions == 0.
520
521    See __getitem__.
522
523    Args:
524      start: the first dimension. 0 <= start <= rank
525      stop: the last dimension (exclusive). 0 <= stop <= rank
526    """
527    if stop <= start:
528      return DynamicRaggedShape._from_inner_shape([])
529    elif start == 0:
530      if stop <= self.num_row_partitions:
531        if stop == 1:
532          return DynamicRaggedShape._from_inner_shape(
533              [self.row_partitions[0].nrows()])
534        new_row_partitions = self.row_partitions[:stop - 1]
535        new_inner_shape = [new_row_partitions[-1].nvals()]
536        return DynamicRaggedShape(new_row_partitions, new_inner_shape)
537      else:
538        if self.rank is None:
539          new_inner_rank = stop - self.num_row_partitions
540          new_inner_shape = self.inner_shape[:new_inner_rank]
541          return DynamicRaggedShape(
542              row_partitions=self.row_partitions,
543              inner_shape=new_inner_shape,
544              static_inner_shape=None,
545              validate=False)
546
547        elif self.rank <= stop:
548          return self
549        new_inner_rank = stop - self.num_row_partitions
550        new_inner_shape = self.inner_shape[:new_inner_rank]
551        return DynamicRaggedShape(
552            row_partitions=self.row_partitions,
553            inner_shape=new_inner_shape,
554            static_inner_shape=tensor_shape.TensorShape([None]
555                                                        * new_inner_rank),
556            validate=False)
557    else:
558      if self.rank is None or stop < self.rank:
559        partial = self._slice_shape(0, stop)
560      else:
561        partial = self
562
563      for x in partial.row_partitions:
564        if not x.is_uniform():
565          raise ValueError("All relevant dimensions must be uniform")
566      if partial.rank is None:
567        # TODO(martinz): Implement _with_num_row_partitions(0) if rank is
568        # unknown, and remove.
569        raise NotImplementedError(
570            "__getitem__[start:stop] where start > 0 not implemented")
571
572      return DynamicRaggedShape._from_inner_shape(
573          partial._with_num_row_partitions(0).inner_shape[start:])
574
575  def _dimension(self, index):
576    """Return a dimension, if the dimension is not ragged (see __getitem__)."""
577    rank = self.rank
578    if not isinstance(index, int):
579      raise TypeError("index should be an int")
580    if (self.num_row_partitions == 0 or index > self.num_row_partitions + 1):
581      # If num_row_partitions > 0 and index <= num_row_partitions + 1, then
582      # we are safe.
583      if rank is None:
584        raise ValueError(
585            "Rank must be known to use __getitem__ on a large index.")
586      if index >= rank:
587        raise IndexError("Index is too big: " + str(index) + ">=" + str(rank))
588    if index < 0:
589      raise IndexError("Index must be non-negative: " + str(index))
590    elif not self.is_uniform(index):
591      raise ValueError("Index " + str(index) + " is not uniform")
592    elif index == 0 and self.num_row_partitions > 0:
593      static_nrows = self.row_partitions[0].static_nrows
594      if static_nrows is not None:
595        return constant_op.constant(static_nrows, dtype=self.dtype)
596      return self.row_partitions[0].nrows()
597    elif self.num_row_partitions == 0:
598      static_result = tensor_shape.dimension_value(
599          self._static_inner_shape[index])
600      if static_result is not None:
601        return constant_op.constant(static_result, dtype=self.dtype)
602      return self.inner_shape[index]
603    elif index > self.num_row_partitions:
604      static_result = tensor_shape.dimension_value(
605          self._static_inner_shape[index - self.num_row_partitions])
606      if static_result is not None:
607        return constant_op.constant(static_result, dtype=self.dtype)
608
609      return self.inner_shape[index - self.num_row_partitions]
610    else:
611      return self.row_partitions[index - 1].uniform_row_length()
612
613  def __getitem__(self, index):
614    """Returns a dimension or a slice of the shape.
615
616    Ragged shapes can have ragged dimensions that depend upon other dimensions.
617    Therefore, if you ask for a dimension that is ragged, this function returns
618    a ValueError. For similar reasons, if a slice is selected that includes
619    a ragged dimension without including the zero dimension, then this fails.
620
621    Any slice that does not start at zero will return a shape
622    with num_row_partitions == 0.
623
624    Args:
625      index: the index: can be an int or a slice.
626
627    Raises:
628      IndexError: if the index is not in range.
629      ValueError: if the rank is unknown, or a ragged rank is requested
630      incorrectly.
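
    #### Example:

    An illustrative sketch (output assumes eager execution):

    >>> shape = tf.experimental.DynamicRaggedShape.from_lengths(
    ...     [2, (3, 2), 7])
    >>> shape[:2]
    <DynamicRaggedShape lengths=[2, (3, 2)] num_row_partitions=1>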
631    """
632    rank = self.rank
633    if isinstance(index, slice):
634
635      if (index.step is not None) and (index.step != 1):
636        raise IndexError("Cannot stride through a shape")
637      start = index.start
638      stop = index.stop
639      if start is None:
640        start = 0
641      start = _fix_start_index(start, rank, self.num_row_partitions)
642      stop = _fix_stop_index(stop, rank)
643      return self._slice_shape(start, stop)
644    elif isinstance(index, int):
645      if index < 0:
646        if rank is None:
647          raise ValueError(
648              "Rank must be known to use __getitem__ with a negative index.")
649        return self._dimension(rank + index)
650      return self._dimension(index)
651    else:
652      raise TypeError("Argument is not an int or a slice")
653
654  def _num_elements(self):
655    """Number of elements in a shape.
656
657    Returns:
658      The number of elements in the shape.
659
660    """
661    return math_ops.reduce_prod(self.inner_shape)
662
663  def _num_slices_in_dimension(self, axis):
664    """The total size of a dimension (like nvals).
665
666    Effectively, this is self[:axis+1]._num_elements()
667
668    Example:
669    shape = DynamicRaggedShape._from_inner_shape([2, 3, 4])
670    shape._num_slices_in_dimension(0) = 2
671    shape._num_slices_in_dimension(1) = 6
672    shape._num_slices_in_dimension(2) = 24
673    shape._num_slices_in_dimension(-1) = 24
674    shape._num_slices_in_dimension(-2) = 6
675    shape._num_slices_in_dimension(-2) = 2
676
677    Args:
678      axis: the last axis to include in the number of elements. If negative,
679        then axis = axis + rank.
680
681    Returns:
682      The number of elements in the shape.
683    """
684    if not isinstance(axis, int):
685      raise TypeError("axis must be an integer")
686    if axis < 0:
687      rank = self.rank
688      if rank is None:
689        raise ValueError(
690            "You can't use negative values if the rank is undefined")
691      axis = axis + rank
692    if axis == 0:
693      return self._dimension(0)
694    if axis <= self.num_row_partitions:
695      return self.row_partitions[axis - 1].nvals()
696    # If self.num_row_partitions = 1, and
697    # self.inner_shape=[3,5,6], and axis=2, then you want:
698    # 15 = 3 * 5 = math_ops.reduce_prod(self.inner_shape[:2])
699    # 2 = axis - (self.num_row_partitions - 1)
700    # If num_row_partitions=0, and
701    # self.inner_shape=[3,5,6] and axis=2, then you want:
702    # 90 = 3 * 5 * 6 = math_ops.reduce_prod(self.inner_shape[:3])
703    # 3 = axis - (self.num_row_partitions - 1)
704    remainder = axis - (self.num_row_partitions - 1)
705    return _reduce_prod_patch(self.inner_shape[:remainder])
706
707  def is_uniform(self, axis):
708    """Returns true if the indicated dimension is uniform."""
709    if not isinstance(axis, int):
710      raise TypeError("axis must be an integer")
711    rank = self.rank
712    if axis < 0:
713      raise IndexError("Negative axis values are not supported")
714    elif rank is not None and axis >= rank:
715      raise IndexError("Expected axis=%s < rank=%s" % (axis, rank))
716    else:
717      return ((axis == 0 or axis > len(self._row_partitions))  # pylint:disable=superfluous-parens
718              or self._row_partitions[axis - 1].is_uniform())
719
720  @property
721  def rank(self):
722    """The number of dimensions in this shape, or None if unknown."""
723    inner_rank = self.inner_rank
724    if inner_rank is None:
725      return None
726    else:
727      return self.num_row_partitions + inner_rank
728
729  @property
730  def inner_shape(self):
731    """The inner dimension sizes for this shape.
732
733    Returns:
734      A 1-D integer `Tensor`.
735    """
736    return self._inner_shape
737
738  @property
739  def inner_rank(self):
740    """The rank of inner_shape."""
741    return tensor_shape.dimension_value(self._static_inner_shape.rank)
742
743  def _alt_inner_shape(self, new_inner_rank):
744    """Get an alternative inner shape with higher or lower rank.
745
746    For the rank of the inner shape to be be higher, the last few ragged
747    dimensions must have uniform_row_length.
748
749    Args:
750      new_inner_rank: the new rank of the inner_shape
751
752    Returns:
753       A new inner_shape of rank new_inner_rank.
754    """
755    if new_inner_rank == 0:
756      raise ValueError("new_inner_rank cannot be zero")
757    elif self.inner_rank == 0:
758      raise ValueError("old inner_rank cannot be zero")
759    elif new_inner_rank == self.inner_rank:
760      return self.inner_shape
761    elif new_inner_rank < self.inner_rank:
762      if self._static_inner_shape.is_fully_defined():
763        return _alt_inner_shape_from_tensor_shape(self._static_inner_shape,
764                                                  self.dtype, new_inner_rank)
765      first_dimension = self._num_slices_in_dimension(-new_inner_rank)
766      if new_inner_rank == 1:
767        return array_ops.expand_dims(first_dimension, 0)
768      remaining_dimensions = self.inner_shape[1 - new_inner_rank:]
769      return array_ops.concat(
770          [array_ops.expand_dims(first_dimension, 0), remaining_dimensions],
771          axis=0)
772    else:
773      assert new_inner_rank > self.inner_rank
774      new_dimensions = new_inner_rank - self.inner_rank
775      if any(
776          [not x.is_uniform() for x in self.row_partitions[-new_dimensions:]]):
777        raise ValueError("Cannot get an inner shape over a ragged dimension")
778      first_dimension = self._num_slices_in_dimension(-new_inner_rank)
779      new_dimensions = new_inner_rank - self.inner_rank
780      new_dims = [first_dimension] + [
781          x.uniform_row_length() for x in self.row_partitions[-new_dimensions:]
782      ]
783      return array_ops.concat([array_ops.stack(new_dims), self.inner_shape[1:]],
784                              axis=0)
785
786  def _inner_shape_dim(self, dimension):
787    """Returns an int or a tensor representing _inner_shape[dimension]."""
788    result = tensor_shape.dimension_value(self._static_inner_shape[dimension])
789    return self._inner_shape[dimension] if result is None else result
790
791  def _with_inner_rank(self, inner_rank):
792    """Returns the same shape but a different inner_rank.
793
794    All dimensions that are to be represented in the inner_shape must be dense.
795    See inner_rank.
796
797    Args:
798      inner_rank: the new inner_rank of the shape.
799
800    Returns:
801      the same shape but a different inner_rank
802
803    Raises:
804      ValueError if the new dense rank is invalid, or the old rank is unknown.
805    """
806    rank = self.rank
807    if rank is None:
808      raise ValueError("Rank must be known to adjust inner_rank")
809    elif rank < 2:
810      if inner_rank == rank:
811        return self
812      raise ValueError("Cannot change inner_rank if rank < 2")
813    else:
814      # When self.rank is not None:
815      # self.rank = self.inner_rank + self.num_row_partitions
816      new_num_row_partitions = rank - inner_rank
817      return self._with_num_row_partitions(new_num_row_partitions)
818
819  def _with_num_row_partitions(self, num_row_partitions):
820    """Creates an identical shape with the given num_row_partitions.
821
822    Note that the shape must be statically refactorable to this rank.
823    In particular:
824    * rank must be known.
825    * num_row_partitions must be a nonnegative int.
826    * num_row_partitions must be less than the rank of the shape
827    * num_row_partitions must be greater or equal to the index of any ragged
828    dimension.
829
830    Note that if the num_row_partitions is the same, self is returned.
831
832    Args:
833      num_row_partitions: the target num_row_partitions (must be a nonnegative
834        int).
835
836    Returns:
837      a shape with a (possibly) different num_row_partitions.
838
839    Raises:
840      ValueError: if the rank is unknown, the argument is not a nonnegative int,
841        or there is a dimension that is nonuniform.
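
    #### Example:

    An illustrative sketch (output assumes eager execution):

    >>> shape = tf.experimental.DynamicRaggedShape.from_lengths([2, 2, 3])
    >>> shape._with_num_row_partitions(2)
    <DynamicRaggedShape lengths=[2, 2, 3] num_row_partitions=2>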
842    """
843    rank = self.rank
844    if rank is None:
845      raise ValueError("Rank must be known to adjust num_row_partitions")
846    if not isinstance(num_row_partitions, int):
847      raise ValueError("num_row_partitions must be an int")
848    if num_row_partitions < 0:
849      raise ValueError("num_row_partitions must be nonnegative")
850    if num_row_partitions == self.num_row_partitions:
851      return self
852    if num_row_partitions >= rank:
853      raise ValueError("num_row_partitions must be less than rank")
854    if num_row_partitions > self.num_row_partitions:
855      num_row_partitions_diff = num_row_partitions - self.num_row_partitions
856      new_inner_rank = self.rank - num_row_partitions
857      nvals = self._inner_shape_dim(0)
858      more_rp = []
859      for i in range(num_row_partitions_diff):
860        nrows = nvals
861        row_length = self._inner_shape_dim(i + 1)
862        nvals = nrows * row_length
863        rp = RowPartition.from_uniform_row_length(
864            row_length, nrows=nrows, dtype=self.dtype)
865        more_rp.append(rp)
866      alt_inner = self._alt_inner_shape(new_inner_rank)
867      return DynamicRaggedShape(
868          list(self.row_partitions) + more_rp, alt_inner)
869    else:
870      assert num_row_partitions < self.num_row_partitions
871      return DynamicRaggedShape(
872          self.row_partitions[:num_row_partitions],
873          self._alt_inner_shape(self.rank - num_row_partitions))
874
875  def _merge_dims(self, outer_axis: int,
876                  inner_axis: int) -> "DynamicRaggedShape":
877    """Merges outer_axis...inner_axis into a single dimension.
878
879    Returns a copy of this shape with the specified range of dimensions
880    flattened into a single dimension, with elements in row-major order.
881
882    #### Examples:
883
884    >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1), (1,2,3)])._merge_dims(0, 1)  # pylint: disable=line-too-long
885    <DynamicRaggedShape lengths=[3, (1, 2, 3)] num_row_partitions=1>
886    >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1), (1,2,3)])._merge_dims(1, 2)  # pylint: disable=line-too-long
887    <DynamicRaggedShape lengths=[2, (3, 3)] num_row_partitions=1>
888    >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1), (1,2,3)])._merge_dims(0, 2)  # pylint: disable=line-too-long
889    <DynamicRaggedShape lengths=[6] num_row_partitions=0>
890
891    To mimic the behavior of `np.flatten` (which flattens all dimensions), use
892    `rt.merge_dims(0, -1).  To mimic the behavior of `tf.layers.Flatten` (which
893    flattens all dimensions except the outermost batch dimension), use
894    `rt.merge_dims(1, -1)`.
895
896    Args:
897      outer_axis: `int`: The first dimension in the range of dimensions to
898        merge. May be negative if `self.shape.rank` is statically known.
899      inner_axis: `int`: The last dimension in the range of dimensions to merge.
900        May be negative if `self.shape.rank` is statically known.
901
902    Returns:
903      A copy of this shape, with the specified dimensions merged into a
904      single dimension.  The returned shape will be
905      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
906      is the total number of slices in the merged dimensions.
907    """
908    outer_axis = array_ops.get_positive_axis(
909        outer_axis,
910        self.rank,
911        axis_name="outer_axis",
912        ndims_name="rank(self)")
913    inner_axis = array_ops.get_positive_axis(
914        inner_axis,
915        self.rank,
916        axis_name="inner_axis",
917        ndims_name="rank(self)")
918    if not outer_axis <= inner_axis:
919      raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or "
920                       f"equal to inner_axis ({inner_axis}).")
921    if outer_axis == inner_axis:
922      return self
923    if self.num_row_partitions == 0:
924      # A dense tensor.
925      (new_inner_shape, new_static_inner_shape) = _merge_inner_shape(
926          self._inner_shape, self._static_inner_shape, outer_axis, inner_axis)
927      return DynamicRaggedShape([],
928                                new_inner_shape,
929                                dtype=self.dtype,
930                                static_inner_shape=new_static_inner_shape)
931    if inner_axis <= self.num_row_partitions:
932      # Here, we are merging the row_partitions,
933      # but the inner_shape is unchanged.
934      if outer_axis == 0:
935        # There is no need to merge axes before the first, just truncate them.
936        return DynamicRaggedShape(
937            self._row_partitions[inner_axis:],
938            self.inner_shape,
939            dtype=self.dtype,
940            static_inner_shape=self._static_inner_shape)
941      prefix_rp = self._row_partitions[:outer_axis - 1]
942      suffix_rp = self._row_partitions[inner_axis:]
943      internal_rp = self._row_partitions[outer_axis - 1:inner_axis]
944      new_rp = prefix_rp + (_merge_row_partitions(internal_rp),) + suffix_rp
945
946      return DynamicRaggedShape(
947          new_rp, self.inner_shape, dtype=self.dtype,
948          static_inner_shape=self._static_inner_shape)
949    elif outer_axis > self.num_row_partitions:
950      # In this scenario, only the inner_shape is changed.
951      # Example #1:
952      # if [2, (1, 2), 5, 3], num_row_partitions=1, outer_axis=2, inner_axis=3.
953      # Result: [2, (1, 2), 15], num_row_partitions=1, outer_axis=2,
954      #     inner_axis=3.
955      (new_inner_shape, new_static_inner_shape) = _merge_inner_shape(
956          self._inner_shape, self._static_inner_shape,
957          outer_axis-self.num_row_partitions,
958          inner_axis-self.num_row_partitions)
959      return DynamicRaggedShape(
960          self._row_partitions,
961          new_inner_shape, dtype=self.dtype,
962          static_inner_shape=new_static_inner_shape)
963    else:
964      # Here, both inner_shape and row_partitions are changed.
965      rank = self.rank
966      if rank is None:
967        raise ValueError("Cannot merge_dims of the inner shape if the " +
968                         "dimension of inner_shape is unknown")
969      if outer_axis == 0:
970        new_inner_shape = self._alt_inner_shape(rank - inner_axis)
971        return DynamicRaggedShape._from_inner_shape(new_inner_shape)
972      else:
973        prefix = self._row_partitions[:outer_axis-1]
974        suffix = _merge_row_partitions(self._row_partitions[outer_axis-1:])
975        new_inner_shape = self._alt_inner_shape(rank - inner_axis)
976        num_merged_inner = inner_axis - self.num_row_partitions
977        prod = _reduce_prod_patch(self._inner_shape[1:num_merged_inner + 1])
978        tail_suffix = RowPartition.from_row_splits(suffix.row_splits() * prod)
979        return DynamicRaggedShape(prefix + (tail_suffix,), new_inner_shape)
980
981  def with_dtype(self, dtype):
982    """Change the dtype of the shape."""
983    if dtype == self.dtype:
984      return self
985    else:
986      return DynamicRaggedShape(
987          self.row_partitions, self.inner_shape, dtype=dtype)
988
989  def _merge_with(self, other: "DynamicRaggedShape") -> "DynamicRaggedShape":
990    """Merge two shapes that are equal modulo num_row_partitions.
991
992    The resulting num_row_partitions is the maximum of the two
993    num_row_partitions.
994
995    Args:
996      other: a DynamicRaggedShape representing the same shape with a possibly
997      different number of row partitions.
998
999    Returns:
1000      A DynamicRaggedShape with the same shape and the maximum of the
1001      num_row_partitions of the two shapes.
1002    """
1003    max_num_row_partitions = max(self.num_row_partitions,
1004                                 other.num_row_partitions)
1005    a = self._with_num_row_partitions(max_num_row_partitions)
1006    b = other._with_num_row_partitions(max_num_row_partitions)
1007    new_row_partitions = [
1008        rp_a._merge_precomputed_encodings(rp_b)
1009        for (rp_a, rp_b) in zip(a._row_partitions, b._row_partitions)
1010    ]
1011    new_dtype = b.dtype if a.dtype == dtypes.int32 else dtypes.int64
1012
1013    new_static_inner_shape = a._static_inner_shape.merge_with(
1014        b._static_inner_shape)
1015    new_inner_shape = a._inner_shape
1016    return DynamicRaggedShape(new_row_partitions, new_inner_shape, new_dtype,
1017                              True, new_static_inner_shape)
1018
1019  def _merge_with_spec(
1020      self, other: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape":
1021    """Merge a spec with a DynamicRaggedShape."""
1022    # TODO(martinz): add tests for dynamic inconsistencies.
1023    max_num_row_partitions = max(self.num_row_partitions,
1024                                 other.num_row_partitions)
1025    a = self._with_num_row_partitions(max_num_row_partitions)
1026    b = other._with_num_row_partitions(max_num_row_partitions)
1027    new_row_partitions = [rp_a._merge_with_spec(rp_b) for (rp_a, rp_b) in
1028                          zip(a._row_partitions, b._row_partitions)]
1029    new_dtype = b.dtype if a.dtype == dtypes.int32 else dtypes.int64
1030
1031    new_static_inner_shape = a._static_inner_shape.merge_with(
1032        b._static_inner_shape)
1033    new_inner_shape = a._inner_shape
1034    return DynamicRaggedShape(
1035        new_row_partitions,
1036        new_inner_shape,
1037        new_dtype,
1038        True,
1039        new_static_inner_shape)
1040
1041  def _as_row_partitions(self):
1042    """Returns row partitions representing this shape.
1043
1044    In order to represent a shape as row partitions, the rank of the shape
1045    must be known, and the shape must have rank at least one.
1046
1047    Returns:
1048      A list of RowPartition objects.
1049    Raises:
1050      ValueError, if the shape cannot be represented by RowPartitions.
1051    """
1052    rank = self.rank
1053    if rank is None:
1054      raise ValueError("rank must be known for _as_row_partitions")
1055    elif rank < 1:
1056      raise ValueError("rank must be >= 1 for _as_row_partitions")
1057    fully_ragged = self._with_num_row_partitions(rank - 1)
1058    return fully_ragged.row_partitions
1059
1060  def _validate_flat_values_dynamically(self, flat_values):
1061    """Test if flat_values have the right nvals dynamically."""
1062    if self.row_partitions:
1063      assert_op = check_ops.assert_equal(
1064          self.row_partitions[-1].nvals(),
1065          array_ops.shape(flat_values, out_type=self.dtype)[0],
1066          message="Last row partition does not match flat_values.")
1067      return control_flow_ops.with_dependencies([assert_op], flat_values)
1068    return flat_values
1069
1070  def _validate_flat_values(self, flat_values):
1071    """Test if flat_values have the right nvals."""
1072    if not isinstance(flat_values, ops.Tensor):
1073      return flat_values
1074    if self.row_partitions:
1075      last_row_partition = self.row_partitions[-1]
1076      flat_values_shape = flat_values.shape
1077      if flat_values_shape is None:
1078        return self._validate_flat_values_dynamically(flat_values)
1079      first_dim_flat_values = flat_values_shape[0]
1080      if isinstance(first_dim_flat_values, tensor_shape.Dimension):
1081        first_dim_flat_values = first_dim_flat_values.value
1082      if first_dim_flat_values is None:
1083        return self._validate_flat_values_dynamically(flat_values)
1084      static_nvals = last_row_partition.static_nvals
1085      if static_nvals is None:
1086        return self._validate_flat_values_dynamically(flat_values)
1087      if first_dim_flat_values != static_nvals:
1088        raise ValueError("Last row partition does not match flat_values.")
1089    return flat_values
1090
1091  def _add_row_partitions(self, flat_values, validate=False):
1092    """Add row partitions to flat_values, if necessary.
1093
1094    If the shape is truly ragged, then this adds the row_partitions.
1095
1096    The shape is dense, then this just returns flat_values.
1097
1098    Args:
1099      flat_values: the flat_values of a ragged tensor with this shape, or a
1100        dense tensor with this shape.
1101      validate: validate the flat_values have the right first dimension.
1102
1103    Returns:
1104      flat_values reshaped to have row_partitions.
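
    #### Example:

    An illustrative sketch (output assumes eager execution):

    >>> shape = tf.experimental.DynamicRaggedShape.from_lengths(
    ...     [3, (2, 0, 3)])
    >>> shape._add_row_partitions(tf.constant([1, 2, 3, 4, 5]))
    <tf.RaggedTensor [[1, 2], [], [3, 4, 5]]>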
1105    """
1106    if self.row_partitions:
1107      if validate:
1108        flat_values = self._validate_flat_values(flat_values)
1109      return ragged_tensor.RaggedTensor._from_nested_row_partitions(
1110          flat_values, self.row_partitions, validate=False)
1111    else:
1112      return flat_values
1113
1114  class Spec:
1115    """A Spec for DynamicRaggedShape: similar to a static shape."""
1116
1117    def __init__(self, row_partitions: Tuple[RowPartitionSpec, ...],
1118                 static_inner_shape: tensor_shape.TensorShape,
1119                 dtype: dtypes.DType):
1120      """Create a Spec given row partitions, a static inner shape, and a dtype.
1121
1122      Args:
1123        row_partitions: A sequence of `RowPartitionSpec`s describing how the
1124            ragged shape is partitioned.
1125        static_inner_shape: The static shape of the flat_values.
1126        dtype: The DType used to encode the shape (tf.int64 or tf.int32).
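
      #### Example:

      An illustrative sketch (output assumes eager execution):

      >>> spec = tf.experimental.DynamicRaggedShape.Spec(
      ...     row_partitions=[],
      ...     static_inner_shape=tf.TensorShape([2, 3]),
      ...     dtype=tf.int64)
      >>> spec.rank
      2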
1127      """
1128      # Independent validation and coercion of each argument.
1129      if not isinstance(row_partitions, Iterable):
1130        raise TypeError("row_partitions should be an Iterable")
1131
1132      row_partitions = tuple(row_partitions)
1133
1134      static_inner_shape = tensor_shape.as_shape(static_inner_shape)
1135
1136      dtype = dtypes.as_dtype(dtype)
1137
1138      if not all(isinstance(rp, RowPartitionSpec) for rp in row_partitions):
1139        raise TypeError(
1140            "row_partitions should be an Iterable of RowPartitionSpecs")
1141
1142      if dtype != dtypes.int32 and dtype != dtypes.int64:
1143        raise ValueError("dtype must be tf.int32 or tf.int64")
1144
1145      # All fields are now typechecked and internally consistent.
1146      for spec in row_partitions:
1147        if spec.dtype != dtype:
1148          raise ValueError(
1149              f"dtype of {spec!r} is {spec.dtype!r}: expected {dtype!r}")
1150
1151      row_partitions = tuple(row_partitions)
1152
1153      inner_rank = static_inner_shape.rank
1154
1155      if inner_rank == 0:
1156        if row_partitions:
1157          raise ValueError(
1158              "If row_partitions are provided, must have inner_rank > 0")
1159      else:
1160        num_slices_in_dimension = []   # type: Sequence[tensor_shape.Dimension]
1161
1162        # We first attempt to calculate num_slices_in_dimension through a
1163        # forward pass, using nrows[k] = nrows[k-1] * uniform_row_length
1164        # and other tricks.
1165        for i in range(len(row_partitions)):
1166          rp = row_partitions[i]
1167          result = tensor_shape.Dimension(rp.nrows)
1168          if i > 0:
1169            previous_rp = row_partitions[i - 1]
1170            result = result.merge_with(previous_rp.nvals)
1171            result = result.merge_with(num_slices_in_dimension[-1] *
1172                                       previous_rp.uniform_row_length)
1173          num_slices_in_dimension.append(result)
1174        # In the last step of the forward pass,
1175        # we combine nvals and the first dimension in static_inner_shape.
1176        if row_partitions:
1177          last_rp = row_partitions[-1]
1178          result = (num_slices_in_dimension[-1] *
1179                    last_rp.uniform_row_length).merge_with(last_rp.nvals)
1180          if inner_rank is not None:
1181            result = result.merge_with(
1182                tensor_shape.dimension_at_index(static_inner_shape, 0))
1183            static_inner_shape = result + static_inner_shape[1:]
1184          num_slices_in_dimension.append(result)
1185
1186        # Now, we start a backward pass.
1187        for i in range(len(num_slices_in_dimension) - 1, 0, -1):
1188          num_slices_in_dimension[i - 1] = num_slices_in_dimension[
1189              i - 1].merge_with(
1190                  _safe_floor_div(num_slices_in_dimension[i],
1191                                  row_partitions[i - 1].uniform_row_length))
1192
1193        # Finally, we construct the partitions.
1194        row_partitions = [
1195            RowPartitionSpec(  # pylint: disable=g-complex-comprehension
1196                nrows=num_slices_in_dimension[i].value,
1197                uniform_row_length=rp.uniform_row_length,
1198                nvals=num_slices_in_dimension[i + 1].value,
1199                dtype=rp.dtype) for i, rp in enumerate(row_partitions)
1200        ]
1201
1202      self._static_inner_shape = static_inner_shape
1203      self._inner_shape = tensor_spec.TensorSpec(
1204          [inner_rank], dtype=dtype)
1205      self._row_partitions = row_partitions
1206
1207    def __repr__(self):
1208      return (
1209          f"DynamicRaggedShape.Spec(row_partitions={self._row_partitions!r}, " +
1210          f"static_inner_shape={self._static_inner_shape!r}, " +
1211          f"dtype={self.dtype!r})")
1212
1213    @classmethod
1214    def from_value(cls, value: Any) -> "DynamicRaggedShape.Spec":
1215      """Create a Spec from a DynamicRaggedShape."""
1216      # super().from_value(...) creates an object, but there is no validation.
1217      # No methods can be trusted on the object, just the properties.
1218      initial = super(DynamicRaggedShape.Spec, cls).from_value(value)
1219
1220      # However, since value is a DynamicRaggedShape, we
1221      # can guarantee that initial._inner_shape.shape.rank == 1
1222
1223      # Moreover, if inner_shape.shape[0] is not None, then
1224      # static_inner_shape.rank is not None.
1225
1226      return DynamicRaggedShape.Spec(
1227          row_partitions=initial._row_partitions,
1228          static_inner_shape=initial._static_inner_shape,
1229          dtype=initial._inner_shape.dtype)
1230
1231    # TODO(martinz): it is unclear what the default uniformity of RowPartitions
1232    # should be, so I am moving this to experimental until we figure it out.
1233    # Also, while I have specified this is meant to represent a shape of a
1234    # proper Tensor instead of a RaggedTensor, this is also subject to
1235    # interpretation.
1236    @classmethod
1237    def _from_tensor_shape(cls,
1238                           shape: Any,
1239                           num_row_partitions: int,
1240                           dtype: dtypes.DType) -> "DynamicRaggedShape.Spec":
1241      """Creates a `DynamicRaggedShape.Spec` corresponding to a `tf.TensorShape`.
1242
1243      It is assumed that this is a `tf.TensorShape` coming from a
1244      `tf.TensorSpec`, not from `RaggedTensor.shape`.
1245
1246      In addition to the shape, we need to know the number of row partitions,
1247      and the dtype used in the shape (tf.int32 or tf.int64).
1248
1249      Within the dimensions that are partitioned, all dimensions are assumed
1250      to be uniform.
1251
1252      Args:
1253        shape: a TensorShape.
1254        num_row_partitions: the ragged rank of the RaggedShape.
1255        dtype: the dtype of the shape (not the tensor); tf.int64 or tf.int32.
1256
1257      Returns:
1258        a DynamicRaggedShape.Spec representing a TensorShape.
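
      #### Example:

      An illustrative sketch, calling this private helper directly:

      >>> spec = tf.experimental.DynamicRaggedShape.Spec._from_tensor_shape(
      ...     tf.TensorShape([3, 2, None]), num_row_partitions=1,
      ...     dtype=tf.int64)
      >>> spec.rank
      3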
1259      """
1260      if dtype != dtypes.int32 and dtype != dtypes.int64:
1261        raise ValueError("dtype must be tf.int32 or tf.int64")
1262
1263      shape = tensor_shape.as_shape(shape)
1264      if shape.rank is None:
1265        row_partitions = [
1266            RowPartitionSpec(dtype=dtype) for _ in range(num_row_partitions)
1267        ]
1268        return DynamicRaggedShape.Spec(
1269            row_partitions=row_partitions,
1270            static_inner_shape=tensor_shape.TensorShape(None),
1271            dtype=dtype)
1272
1273      if shape.rank <= 1:
1274        # Create a scalar or vector shape.
1275        if num_row_partitions:
1276          raise ValueError("num_row_partitions should be zero " +
1277                           "if shape is a scalar or vector.")
1278        return DynamicRaggedShape.Spec(
1279            row_partitions=[], static_inner_shape=shape, dtype=dtype)
1280
1281      if shape.rank <= num_row_partitions:
1282        raise ValueError("num_row_partitions must be less than rank")
1283
1284      num_elements_so_far = tensor_shape.dimension_value(shape[0])
1285      rp_specs = []
1286      for i in range(num_row_partitions):
1287        current_dim = tensor_shape.dimension_value(shape[i + 1])
1288        if current_dim is None or num_elements_so_far is None:
1289          nvals = None
1290        else:
1291          nvals = num_elements_so_far * current_dim
1292        rp_specs.append(RowPartitionSpec(
1293            nrows=num_elements_so_far,
1294            nvals=nvals,
1295            uniform_row_length=current_dim,
1296            dtype=dtype))
1297        num_elements_so_far = nvals
1298
1299      static_inner_shape = tensor_shape.TensorShape(
1300          [num_elements_so_far]) + shape[num_row_partitions + 1:]
1301      return DynamicRaggedShape.Spec(
1302          row_partitions=rp_specs,
1303          static_inner_shape=static_inner_shape,
1304          dtype=dtype)
1305
1306    @classmethod
1307    def _from_spec(
1308        cls,
1309        spec: Union["DynamicRaggedShape.Spec", ragged_tensor.RaggedTensorSpec,
1310                    tensor_spec.TensorSpec],
1311        dtype: dtypes.DType = dtypes.int64) -> "DynamicRaggedShape.Spec":
1312      """Create a TypeSpec for the shape of an object with a given TypeSpec.
1313
1314      I.e., if `x_spec = tf.type_spec_from_value(x)`, then
1315      `DynamicRaggedShape.Spec._from_spec(x_spec)` returns a TypeSpec
1316      compatible with `tf.type_spec_from_value(tf.shape(x))`.
1317
1318      >>> rt = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
1319      >>> rt_spec = tf.type_spec_from_value(rt)
1320      >>> rt_shape = DynamicRaggedShape.from_tensor(rt)
1321
1322      >>> shape_spec_1 = tf.type_spec_from_value(rt_shape)
1323      >>> shape_spec_2 = DynamicRaggedShape.Spec._from_spec(rt_spec)
1324      >>> assert shape_spec_1.is_compatible_with(shape_spec_2)
1325
1326      Args:
1327        spec: a Spec of a Tensor or RaggedTensor.
1328        dtype: the dtype to use if the spec does not determine one.
1329
1330      Returns:
1331        A Spec of the shape of a Tensor or RaggedTensor.
1332
1333      """
1334      # TODO(martinz): Add StructuredTensor.Spec when it's easy.
1335      if isinstance(spec, DynamicRaggedShape.Spec):
1336        return spec
1337      elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
1338        return cls._from_tensor_shape(spec.shape,
1339                                      spec.ragged_rank,
1340                                      spec.row_splits_dtype)
1341      elif isinstance(spec, tensor_spec.TensorSpec):
1342        return cls._from_tensor_shape(shape=spec.shape,
1343                                      num_row_partitions=0,
1344                                      dtype=dtype)
      else:
        raise TypeError("spec must be a DynamicRaggedShape.Spec, a "
                        "RaggedTensorSpec, or a TensorSpec.")
1345
1346    @property
1347    def dtype(self) -> dtypes.DType:
1348      return self._inner_shape.dtype
1349
1350    @property
1351    def inner_rank(self) -> Optional[int]:
1352      if self._static_inner_shape.rank is not None:
1353        return self._static_inner_shape.rank
1354      if self._inner_shape.shape.rank is None:
1355        return None
1356      return tensor_shape.dimension_value(self._inner_shape.shape[0])
1357
1358    @property
1359    def num_row_partitions(self) -> int:
1360      return len(self._row_partitions)
1361
1362    @property
1363    def rank(self) -> Optional[int]:
1364      inner_rank = self.inner_rank
1365      return None if inner_rank is None else inner_rank + self.num_row_partitions
1366
1367    def _dimension(self, index: int) -> Optional[int]:
1368      """Get the size of dimension index, if known statically."""
1369      if index == 0:
1370        if self._row_partitions:
1371          return self._row_partitions[0].nrows
1372        elif self.inner_rank is None:
1373          return None
1374        elif self.inner_rank == 0:
1375          raise ValueError("Index out of range: 0.")
1376        else:
1377          return tensor_shape.dimension_value(self._static_inner_shape[0])
1378      if index <= len(self._row_partitions):
1379        return self._row_partitions[index - 1].uniform_row_length
1380
1381      relative_index = index - self.num_row_partitions
1382
1383      if self.inner_rank is None:
1384        return None
1385      elif self.inner_rank <= relative_index:
1386        raise ValueError(f"Index out of range: {index}.")
1387      else:
1388        return tensor_shape.dimension_value(
1389            self._static_inner_shape[relative_index])
1390
1391    def _num_slices_in_dimension(self, axis: int) -> Optional[int]:
1392      """The total size of a dimension (like nvals).
1393
1394      This is a static version of DynamicRaggedShape._num_slices_in_dimension()
1395
1396      Example:
1397
1398      ```
1399      shape = DynamicRaggedShape.Spec(
1400        _row_partitions=[
1401          RowPartitionSpec(nrows=3, nvals=14, dtype=tf.int32),
1402          RowPartitionSpec(nrows=14, nvals=25, dtype=tf.int32)
1403        ],
1404        _static_inner_shape=tf.TensorShape([25, 3, 4]),
1405        _inner_shape=tf.TensorSpec(tf.TensorShape([3]), dtype=tf.int32))
1406      shape._num_slices_in_dimension(0) == 3
1407      shape._num_slices_in_dimension(1) == 14
1408      shape._num_slices_in_dimension(2) == 25
1409      shape._num_slices_in_dimension(3) == 75    # 25 * 3
1410      shape._num_slices_in_dimension(4) == 300   # 25 * 3 * 4
1411      shape._num_slices_in_dimension(-2) == 75
1412      ```
1414
1415      Args:
1416        axis: the last dimension to include.
1417
1418      Returns:
1419        the total number of slices in dimension axis.
1420      """
1421      if not isinstance(axis, int):
1422        raise TypeError("axis must be an integer")
1423      axis = array_ops.get_positive_axis(axis, self.rank, ndims_name="rank")
1424
1425      if axis == 0:
1426        return self._dimension(0)
1427      if axis <= self.num_row_partitions:
1428        # TODO(martinz): use nvals OR nrows, whichever is defined.
1429        return self._row_partitions[axis - 1].nvals
1430      remainder = axis - (self.num_row_partitions - 1)
1431      head_inner_shape = self._static_inner_shape[:remainder]
1432      return head_inner_shape.num_elements()
1433
1434    def with_dtype(self, dtype: dtypes.DType) -> "DynamicRaggedShape.Spec":
1435      """Return the same spec, but with a different DType."""
1436      new_rp_specs = [rp.with_dtype(dtype) for rp in self._row_partitions]
1437      return DynamicRaggedShape.Spec(
1438          row_partitions=new_rp_specs,
1439          static_inner_shape=self._static_inner_shape,
1440          dtype=dtype)
1441
1442    def _merge_with(
1443        self,
1444        other: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape.Spec":
1445      """Merges all information between two specs.
1446
1447      Specs are expected to represent the same information modulo
1448      num_row_partitions.
1449
1450      If the specs have different ranks, this fails.
1451
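      For example (a sketch, assuming the field-wise merge described above):

      ```
      a = DynamicRaggedShape.Spec(
          row_partitions=[RowPartitionSpec(nrows=3, dtype=tf.int32)],
          static_inner_shape=tf.TensorShape([None, 5]),
          dtype=tf.int32)
      b = DynamicRaggedShape.Spec(
          row_partitions=[RowPartitionSpec(uniform_row_length=2,
                                           dtype=tf.int32)],
          static_inner_shape=tf.TensorShape([6, None]),
          dtype=tf.int32)
      a._merge_with(b)  # nrows=3, uniform_row_length=2,
                        # static_inner_shape=[6, 5]
      ```
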
1452      Args:
1453        other: another Spec of the same rank.
1454
1455      Returns:
1456        a Spec with the union of information.
1457      """
1458      max_num_row_partitions = max(self.num_row_partitions,
1459                                   other.num_row_partitions)
1460      a = self._with_num_row_partitions(max_num_row_partitions)
1461      b = other._with_num_row_partitions(max_num_row_partitions)
1462
1463      new_rp = [
1464          a._merge_with(b)
1465          for (a, b) in zip(a._row_partitions, b._row_partitions)
1466      ]
1467
1468      new_static_inner_shape = a._static_inner_shape.merge_with(
1469          b._static_inner_shape)
1470
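      # Use int32 only if both specs use int32; otherwise use int64.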
1471      dtype = b.dtype if (a.dtype == dtypes.int32) else dtypes.int64
1472
1473      return DynamicRaggedShape.Spec(
1474          new_rp, new_static_inner_shape, dtype=dtype)
1475
1476    def _with_num_row_partitions(
1477        self,
1478        new_num_row_partitions: int) -> "DynamicRaggedShape.Spec":
1479      """Change the number of row partitions in the spec."""
1480      rank = self.rank
1481      if rank is None:
1482        raise ValueError(
1483            "Changing num_row_partitions with unknown rank unsupported")
1484      if new_num_row_partitions > max(rank - 1, 0):
1485        raise ValueError("Number of row partitions too large")
1486      if new_num_row_partitions < 0:
1487        raise ValueError("Number of row partitions negative")
1488      if self.num_row_partitions == new_num_row_partitions:
1489        return self
1490      elif self.num_row_partitions < new_num_row_partitions:
1491        # TODO(martinz): Consider swapping.
1492        rp_delta = new_num_row_partitions - self.num_row_partitions
1493        tail_shape = DynamicRaggedShape.Spec._from_tensor_shape(
1494            self._static_inner_shape, rp_delta, self.dtype)
1495        return DynamicRaggedShape.Spec(
1496            row_partitions=self._row_partitions + tail_shape._row_partitions,
1497            static_inner_shape=tail_shape._static_inner_shape,
1498            dtype=self.dtype)
1499      else:
1500        assert self.num_row_partitions > new_num_row_partitions
1501        new_row_partitions = self._row_partitions[:new_num_row_partitions]
1502        last_row_partition = new_row_partitions[-1]
1503        old_row_partitions = self._row_partitions[new_num_row_partitions:]
1504        new_static_inner_shape = (
1505            tensor_shape.TensorShape(
1506                [last_row_partition.nvals] +
1507                [x.uniform_row_length for x in old_row_partitions]) +
1508            self._static_inner_shape[1:])
1509        return DynamicRaggedShape.Spec(
1510            new_row_partitions, new_static_inner_shape, self.dtype)
1511
1512    def _set_rank_if_unknown(self, new_rank: int) -> "DynamicRaggedShape.Spec":
1513      """Ensures this has a known rank at least new_rank."""
1514      if new_rank is None:
1515        raise TypeError("new_rank is None, but expected int")
1516      if new_rank < 0:
1517        raise ValueError("Rank must be non-negative")
1518      current_rank = self.rank
1519      if current_rank is not None and current_rank < new_rank:
1520        raise ValueError(
1521            "Rank is {current_rank}, expected at least {new_rank}.".format(
1522                current_rank=current_rank, new_rank=new_rank))
1523
1524      if current_rank is not None:
1525        return self
1526
1527      if self._row_partitions:
1528        new_inner_rank = max(new_rank - self.num_row_partitions, 1)
1529        first_dim = self._row_partitions[-1].nvals
1530        static_inner_shape = tensor_shape.TensorShape(
1531            [first_dim] + [None] * (new_inner_rank - 1))
1532      else:
1533        static_inner_shape = tensor_shape.TensorShape([None] * new_rank)
1534
1535      return DynamicRaggedShape.Spec(
1536          row_partitions=self._row_partitions,
1537          static_inner_shape=static_inner_shape,
1538          dtype=self.dtype)
1539
1540    def _truncate(self, new_rank: int) -> "DynamicRaggedShape.Spec":
1541      """Truncate a ragged shape spec.
1542
1543      For example, if the original spec s was for a shape:
1544      [3, [4, 1], 2, 7]
1545
1546      Then s._truncate(3) is a spec for:
1547      [3, [4, 1], 2]
1548
1549      Args:
1550        new_rank: the new rank
1551
1552      Returns:
1553        A truncated DynamicRaggedShape.Spec.
1554      """
1555      if self.rank is None:
1556        return self._set_rank_if_unknown(new_rank)._truncate(new_rank)
1557
1558      if new_rank == 0:
1559        return DynamicRaggedShape.Spec._from_tensor_shape([], 0, self.dtype)
1560
1561      if new_rank == 1:
1562        vector_size = self._dimension(0)
1563        return DynamicRaggedShape.Spec._from_tensor_shape([vector_size], 0,
1564                                                          self.dtype)
1565
1566      if new_rank < self.num_row_partitions + 1:
1567        new_row_partitions = self._row_partitions[:new_rank - 1]
1568        new_static_inner_shape = tensor_shape.TensorShape(
1569            [new_row_partitions[-1].nvals])
1570        return DynamicRaggedShape.Spec(
1571            row_partitions=new_row_partitions,
1572            static_inner_shape=new_static_inner_shape,
1573            dtype=self.dtype)
1574      else:
1575        remainder = new_rank - self.num_row_partitions
1576        new_static_inner_shape = self._static_inner_shape[:remainder]
1577        return DynamicRaggedShape.Spec(
1578            row_partitions=self._row_partitions,
1579            static_inner_shape=new_static_inner_shape,
1580            dtype=self.dtype)
1581
1582    def _to_tensor_shape(self):
1583      """Get a tensor shape corresponding to this type."""
1584      alt = self
1585      if alt._static_inner_shape.rank is None:
1586        return tensor_shape.TensorShape(None)
1587      if alt._static_inner_shape.rank == 0:
1588        assert not alt._row_partitions
1589        return alt._static_inner_shape
1590      prefix = [alt._dimension(0)]
1591      prefix.extend([rp.uniform_row_length for rp in alt._row_partitions])
1592      suffix = alt._static_inner_shape[1:]
1593      return tensor_shape.TensorShape(prefix) + suffix
1594
1595
1596def broadcast_dynamic_shape(shape_x: DynamicRaggedShape,
1597                            shape_y: DynamicRaggedShape) -> DynamicRaggedShape:
1598  """Returns the shape formed by broadcasting two shapes to be compatible.
1599
1600  1. If shape_x and shape_y both have row_partitions, then fail if their dtypes
1601     don't match.
1602  2. If neither has row_partitions and they have different dtypes,
1603     go with int64.
1604  3. If one has row_partitions, go with that dtype.
1605
1606  Args:
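  For example (a sketch):

  ```
  x = DynamicRaggedShape.from_tensor(tf.ragged.constant([[1, 2], [3]]))
  y = DynamicRaggedShape.from_tensor(tf.constant([[10], [20]]))
  broadcast_dynamic_shape(x, y)  # a shape equal to x's: [2, [2, 1]]
  ```
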
1607    shape_x: A `DynamicRaggedShape`
1608    shape_y: A `DynamicRaggedShape`
1609
1610  Returns:
1611    A `DynamicRaggedShape`.

1612  Raises:
1613    ValueError: If `shape_x` and `shape_y` are not broadcast-compatible.
1614  """
1615  if not isinstance(shape_x, DynamicRaggedShape):
1616    raise TypeError("shape_x must be a DynamicRaggedShape")
1617  if not isinstance(shape_y, DynamicRaggedShape):
1618    raise TypeError("shape_y must be a DynamicRaggedShape")
1619
1620  return broadcast_dynamic_shape_extended(shape_x, shape_y)[0]
1621
1622
1623def broadcast_to(rt_input, shape: DynamicRaggedShape):
1624  """Broadcasts a potentially ragged tensor to a ragged shape.
1625
1626  Tiles `rt_input` as necessary to match the given shape.
1627
1628  Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`.
1629
1630  Args:
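  For example (a sketch):

  ```
  x = tf.constant([[1], [2], [3]])                   # shape [3, 1]
  target = DynamicRaggedShape.from_tensor(
      tf.ragged.constant([[0, 0], [0], [0, 0, 0]]))  # shape [3, [2, 1, 3]]
  broadcast_to(x, target)  # [[1, 1], [2], [3, 3, 3]]
  ```
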
1631    rt_input: The potentially ragged tensor to broadcast.
1632    shape: A `DynamicRaggedShape`
1633
1634  Returns:
1635    A potentially ragged tensor whose values are taken from
1636    `rt_input`, and whose shape matches `shape`.
1637  """
1638  if not isinstance(shape, DynamicRaggedShape):
1639    raise TypeError("shape must be a DynamicRaggedShape")
1640  rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
1641  origin_shape = None
1642  if ragged_tensor.is_ragged(rt_input):
1643    if shape.num_row_partitions != 0:
1644      if rt_input.row_splits.dtype != shape.dtype:
1645        raise ValueError("Cannot coerce row_splits.dtype")
1646    else:
1647      shape = shape.with_dtype(rt_input.row_splits.dtype)
1648    origin_shape = DynamicRaggedShape.from_tensor(rt_input)
1649  else:
1650    if shape.num_row_partitions != 0:
1651      origin_shape = DynamicRaggedShape.from_tensor(rt_input, dtype=shape.dtype)
1652    else:
1653      origin_shape = DynamicRaggedShape.from_tensor(rt_input,
1654                                                    dtype=dtypes.int64)
1655      shape = shape.with_dtype(dtype=dtypes.int64)
1656
1657  broadcaster = _get_broadcaster(origin_shape, shape)
1658  return broadcaster.broadcast(rt_input)
1659
1660
1661def broadcast_dynamic_shape_extended(
1662    a: DynamicRaggedShape, b: DynamicRaggedShape
1663):  #  -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]
1664  """Gets the smallest shape to which a and b can broadcast.
1665
1666  In order to create the smallest shape, one must also do most of the
1667  work to figure out how to transform from the shapes given. Thus, in addition
1668  to returning the shape, it also creates transformations from the
1669  original shapes to the result.
1670
1671  This is the equivalent of:
1672
1673  c = broadcast_dynamic_shape(a, b)
1674  ac = _get_broadcaster(a, c)
1675  bc = _get_broadcaster(b, c)
1676  return (c, ac, bc)
1677
1678  Args:
1679    a: a DynamicRaggedShape
1680    b: a DynamicRaggedShape
1681
1682  Returns:
1683    A triple of a shape and two broadcasters.
1684  """
1685  if a.row_partitions and b.row_partitions:
1686    if a.dtype != b.dtype:
1687      raise ValueError("Dtypes don't match")
1688  elif a.dtype != b.dtype:
1689    if a.row_partitions:
1690      b = b.with_dtype(a.dtype)
1691    elif b.row_partitions:
1692      a = a.with_dtype(b.dtype)
1693    else:
1694      a = a.with_dtype(dtypes.int64)
1695      b = b.with_dtype(dtypes.int64)
1696
1697  if (a.rank is None or b.rank is None):
1698    raise ValueError("Unable to broadcast: unknown rank")
1699  elif a.rank == 0:
1700    return (b, _Broadcaster(a, b, []), _get_identity_broadcaster(b))
1701  elif b.rank == 0:
1702    return (a, _get_identity_broadcaster(a), _Broadcaster(b, a, []))
1703  elif a.rank == 1 and b.rank == 1:
1704    [a_layer, b_layer,
1705     target] = _broadcast_dynamic_shape_one_layer(a.inner_shape, b.inner_shape)
1706    target_shape = DynamicRaggedShape._from_inner_shape(target)  # pylint: disable=protected-access
1707    return (target_shape, _Broadcaster(a, target_shape, [a_layer]),
1708            _Broadcaster(b, target_shape, [b_layer]))
1709
1710  if a.rank > b.rank:
1711    (c, bc, ac) = _broadcast_dynamic_shape_extended_helper(b, a)  # pylint: disable=arguments-out-of-order
1712
1713    return (c, ac, bc)
1714
1715  return _broadcast_dynamic_shape_extended_helper(a, b)
1716
1717
1718def _row_partitions_identical(shape_a, shape_b):
1719  """Returns True iff all row_partitions in shapes are identical."""
1720  return ((shape_a.num_row_partitions == shape_b.num_row_partitions) and all(
1721      a is b for a, b in zip(shape_a.row_partitions, shape_b.row_partitions)))
1722
1723
1724# TODO(martinz): Preserve shapes better (see CL/414806185)
1725@dispatch.dispatch_for_binary_elementwise_apis(ragged_tensor.RaggedOrDense,
1726                                               ragged_tensor.RaggedOrDense)
1727def ragged_binary_elementwise_op_impl(op, x, y):
1728  """Binary elementwise api handler for RaggedTensors."""
1729  x_is_ragged = ragged_tensor.is_ragged(x)
1730  y_is_ragged = ragged_tensor.is_ragged(y)
1731
1732  # Convert args to tensors.
1733  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1734      x, preferred_dtype=(y.dtype if y_is_ragged else None))
1735  y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1736      y, preferred_dtype=x.dtype)
1737
1738  if x_is_ragged and y_is_ragged:
1739    x, y = ragged_tensor.match_row_splits_dtypes(x, y)
1740
1741  if ((x_is_ragged and y_is_ragged) or
1742      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
1743      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
1744    shape_x = DynamicRaggedShape.from_tensor(x)
1745    shape_y = DynamicRaggedShape.from_tensor(y)
1746    if shape_x.dtype != shape_y.dtype:
1747      if not x_is_ragged:
1748        shape_x = shape_x.with_dtype(shape_y.dtype)
1749      elif not y_is_ragged:
1750        shape_y = shape_y.with_dtype(shape_x.dtype)
1751
1752    if _row_partitions_identical(shape_x, shape_y):
1753      # At this point, both x and y must be ragged.
1754      return shape_x._add_row_partitions(  # pylint: disable=protected-access
1755          op(x.flat_values, y.flat_values), validate=False)
1756
1757    (shape_z, bcast_xz,
1758     bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y)
1759    x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False)
1760    y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False)
1761    z_flat = op(x_new_flat, y_new_flat)
1762    return shape_z._add_row_partitions(z_flat, validate=True)  # pylint: disable=protected-access
1763
1764  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
1765  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
1766  mapped_values = op(x_values, y_values)
1767  if isinstance(mapped_values, bool):
1768    return mapped_values  # Special case for tensor_equals.
1769  if ragged_tensor.is_ragged(x):
1770    return x.with_flat_values(mapped_values)
1771  else:
1772    return y.with_flat_values(mapped_values)
1773
1774
1775@dispatch.dispatch_for_binary_elementwise_assert_apis(
1776    ragged_tensor.RaggedOrDense, ragged_tensor.RaggedOrDense)
1777def ragged_binary_elementwise_assert_op_impl(op, x, y):
1778  """Binary elementwise assert api handler for RaggedTensors.
1779
1780  This handles binary assert operations for ragged tensors. Compared with
1781  `ragged_binary_elementwise_op_impl`, this handler does not compute a ragged
1782  tensor as output. Instead, it applies the assert operation `op` to input
1783  tensors based on their ragged shapes and flat_values, and returns the result
1784  of the assertion operation.
1785
1786  Args:
1787    op: a binary assert operation on Tensors.
1788    x: something that can be coerced to a Tensor or RaggedTensor.
1789    y: something that can be coerced to a Tensor or RaggedTensor.
1790
1791  Returns:
1792    the result of the assertion operation.
1793
1794  """
1795  x_is_ragged = ragged_tensor.is_ragged(x)
1796  y_is_ragged = ragged_tensor.is_ragged(y)
1797
1798  # Convert args to tensors.
1799  x = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1800      x, preferred_dtype=(y.dtype if y_is_ragged else None))
1801  y = ragged_tensor.convert_to_tensor_or_ragged_tensor(
1802      y, preferred_dtype=x.dtype)
1803
1804  if x_is_ragged and y_is_ragged:
1805    x, y = ragged_tensor.match_row_splits_dtypes(x, y)
1806
1807  if ((x_is_ragged and y_is_ragged) or
1808      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
1809      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
1810    shape_x = DynamicRaggedShape.from_tensor(x)
1811    shape_y = DynamicRaggedShape.from_tensor(y)
1812    if shape_x.dtype != shape_y.dtype:
1813      if not x_is_ragged:
1814        shape_x = shape_x.with_dtype(shape_y.dtype)
1815      elif not y_is_ragged:
1816        shape_y = shape_y.with_dtype(shape_x.dtype)
1817
1818    if _row_partitions_identical(shape_x, shape_y):
1819      # At this point, both x and y must be ragged.
1820      return op(x.flat_values, y.flat_values)
1821
1822    (_, bcast_xz, bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y)
1823    x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False)
1824    y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False)
1825    return op(x_new_flat, y_new_flat)
1826
1827  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
1828  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
1829  return op(x_values, y_values)
1830
1831
1832def _find_dtype_helper(value, preferred):
1833  """Helper for _find_dtype."""
1834  if preferred is not None:
1835    return preferred
1836  elif isinstance(value, RowPartition):
1837    return value.dtype
1838  elif isinstance(value, dtypes.DType):
1839    return value
1840  elif isinstance(value, int):
1841    return None
1842  elif isinstance(value, list):
1843    return None
1844  elif isinstance(value, tuple):
1845    return None
1846  elif isinstance(value, core.Tensor):
1847    return value.dtype
1848  return value.dtype
1849
1850
1851def _find_dtype(value, preferred):
1852  """Returns the preferred dtype of value or preferred if preferred != None.
1853
1854  This is used as an operator to pass over multiple objects in decreasing order
1855  of priority until there is a preferred dtype for one. For example, if you were
1856  adding three tensor-ish things (some tensors, some lists), and needed a
1857  preferred dtype, you could use this as:
1858
1859  def adding(a, b, c, dtype = None):
1860    dtype = _find_dtype(a, dtype)
1861    dtype = _find_dtype(b, dtype)
1862    dtype = _find_dtype(c, dtype)
1863    if dtype is None:
1864      dtype = tf.float32
1865    ...Code continues here...
1866
1867  Args:
1868    value: a list, tuple, int, RowPartition, DType, or tensor.
1869    preferred: a given dtype. If not None, this will be returned.
1870
1871  Returns:
1872    an optional dtype.
1873  """
1874  result = _find_dtype_helper(value, preferred)
1875  if (result == dtypes.int64 or result == dtypes.int32 or result is None):
1876    return result
1877  raise ValueError("Illegal dtype: " + str(result))
1878
1879
1880def _find_dtype_iterable(
1881    iterable: Iterable[Any],
1882    dtype: Optional[dtypes.DType]) -> Optional[dtypes.DType]:
1883  """Find the preferred dtype of a list of objects.
1884
1885  This will go over the iterable, and use the first object with a preferred
1886  dtype. The dtype passed has highest priority if it is not None.
1887
1888  Args:
1889    iterable: an iterable with things that might have a dtype.
1890    dtype: an overriding dtype, or None.
1891
1892  Returns:
1893    an optional dtype.
1894  """
1895  if dtype is not None:
1896    return dtype
1897  for x in iterable:
1898    dtype = _find_dtype(x, dtype)
1899  return dtype
1900
1901
1902class _LayerBroadcaster(abc.ABC):
1903  """A broadcaster of a single layer.
1904
1905  Although this class does not literally contain a gather_index, the reference
1906  implementation is defined through a gather_index. Thus, any subclasses should
1907  first define the gather_index property. Other methods can be overridden
1908  for optimization, but they must not change the behavior.
1909  """
1910
1911  @property
1912  @abc.abstractmethod
1913  def gather_index(self):
1914    """Returns a 1D tensor.
1915
1916    The size of the 1D tensor is equal to the destination size.
1917
1918    The ith element of the result is the source index of the ith target element.
1919    """
1920    pass
1921
1922  @property
1923  def dtype(self):
1924    """Returns the dtype of the broadcast."""
1925    return self.gather_index.dtype
1926
1927  @abc.abstractmethod
1928  def with_dtype(self, dtype):
1929    """Returns an identical _LayerBroadcaster with a different dtype."""
1930    pass
1931
1932  def __repr__(self):
1933    return str(self.gather_index)
1934
1935  @classmethod
1936  def from_gather_index(cls, gather_index):
1937    """Create a broadcaster from a gather_index."""
1938    return _GatherLayerBroadcaster(gather_index)
1939
1940  @classmethod
1941  def first_layer(cls, nrows_source, nrows_target):
1942    """Create a broadcaster from a gather_index."""
1943    gather_index = _first_layer_gather_index(nrows_source, nrows_target)
1944    return _LayerBroadcaster.from_gather_index(gather_index)
1945
1946  @classmethod
1947  def get_singleton_broadcaster(cls, target_size):
1948    """Broadcast from 1 element to target_size elements."""
1949    return _LayerBroadcaster.from_gather_index(
1950        array_ops.zeros(target_size, dtype=target_size.dtype))
1951
1952  @abc.abstractmethod
1953  def with_dependencies(self, checks):
1954    """Add dependencies to a _LayerBroadcaster.
1955
1956    Args:
1957      checks: a list of ops that need to be run before any tensors from the
1958        Broadcaster are used.
1959
1960    Returns:
1961      a copy of this _LayerBroadcaster with dependencies added.
1962    """
1963    pass
1964
1965  @classmethod
1966  def get_identity_broadcaster(cls, nvals, dtype=None):
1967    """Create an identity broadcaster.
1968
1969    TODO(martinz): an identity broadcaster can be far more efficient than a
1970    generic broadcaster. Add an optimized implementation.
1971    Args:
1972      nvals: the number of values for the broadcaster.
1973      dtype: the dtype of the broadcaster, or None to use the dtype of nvals.
1974    Returns:
1975      an identity broadcaster from [0....nvals-1] to [0...nvals-1]
1976    """
1977    return _GatherLayerBroadcaster(math_ops.range(nvals, dtype=dtype))
1978
1979  def broadcast_tensor(self, tensor):
1980    """Broadcast from a dense tensor.
1981
1982    It is assumed that the first axis of the dense tensor is indexed by the
1983    source shape, and at the end, the first axis of the dense tensor is
1984    indexed by the destination shape.
1985
1986    Args:
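    For example (a sketch): if gather_index is [0, 1, 2, 0, 1, 2] and tensor
    is [[1], [2], [3]], the result is [[1], [2], [3], [1], [2], [3]].
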
1987      tensor: a dense tensor.
1988
1989    Returns:
1990      A dense tensor.
1991    """
1992    return array_ops.gather(tensor, self.gather_index)
1993
1994  def dest_nrows(self):
1995    """Return the number of rows in the resulting gather, or None if tiling."""
1996    return math_ops.cast(
1997        array_ops.shape(self.gather_index)[0], dtype=self.dtype)
1998
1999  def broadcast_row_partition(self, rp):
2000    """Return a new shape where the rows are broadcasted.
2001
2002        *--self--->*
2003        |          |
2004        rp       result
2005        |          |
2006        V          V
2007        *--------->*
2008
2009    This is equivalent to:
2010      return RowPartition.from_row_lengths(self.broadcast(rp.row_lengths()))
2011
2012    However, if the row partition has a uniform row length, that property
2013    is maintained.
2014
2015    Args:
2016      rp: a row partition.
2017
2018    Returns:
2019      a RowPartition representing a broadcast version of this row partition.
2020    """
2021    if not rp.is_uniform():
2022      return RowPartition.from_row_lengths(
2023          self.broadcast_tensor(rp.row_lengths()))
2024    else:
2025      return RowPartition.from_uniform_row_length(
2026          rp.uniform_row_length(),
2027          nvals=rp.uniform_row_length() * self.dest_nrows(),
2028          nrows=self.dest_nrows())
2029
2030  def next_layer(self, original_rp, broadcast_rp):
2031    r"""Create the next layer gather_index whether or not a broadcast happens.
2032
2033       *---------self------->*
2034       |                     |
2035    original_rp           broadcast_rp
2036       |                     |
2037      \|/                   \|/
2038       *--next_broadcaster-->*
2039    Args:
2040      original_rp: the original row partition.
2041      broadcast_rp: the target row partition.
2042
2043    Returns:
2044      the _LayerBroadcaster for the next layer.
2045
2046    """
2047    gather_index = _next_layer_gather_index(self, original_rp, broadcast_rp)
2048    return _LayerBroadcaster.from_gather_index(gather_index)
2049
2050
2051class _GatherLayerBroadcaster(_LayerBroadcaster):
2052  """Implements _LayerBroadcaster with an explicit gather_index.
2053
2054  For example, suppose that the source shape is:
2055  [*],[*,*]
2056  And the target shape is:
2057  [*],[*,*],[*],[*,*]
2058  Then, this can be represented with a map:
2059  [0,1,2,0,1,2]
2060
2061  """
2062
2063  def __init__(self, gather_index):
2064    gather_index = ops.convert_to_tensor(gather_index)
2065    if (gather_index.dtype != dtypes.int64 and
2066        gather_index.dtype != dtypes.int32):
2067      raise ValueError("gather_index must be int64 or int32")
2068    self._gather_index = gather_index
2069
2070  @property
2071  def gather_index(self):
2072    return self._gather_index
2073
2074  def with_dtype(self, dtype):
2075    return _GatherLayerBroadcaster(math_ops.cast(self._gather_index, dtype))
2076
2077  def with_dependencies(self, checks):
2078    new_gather_index = control_flow_ops.with_dependencies(
2079        checks, self._gather_index)
2080    return _GatherLayerBroadcaster(new_gather_index)
2081
2082
2083class _Broadcaster:
2084  """A _Broadcaster represents a transformation from one shape to another.
2085
2086  It provides a transform for each axis of the source shape to the
2087  corresponding axis of the destination shape.
2088
2089  """
2090
2091  def __init__(self,
2092               source_shape,
2093               target_shape,
2094               layer_broadcasters,
2095               dtype=None):
2096    """Create a broadcaster.
2097
2098    Do not call directly.
2099    The source_shape, target_shape, and layer_broadcasters are converted
2100    to have the same dtype.
2101
2102    Note: source_shape.rank and target_shape.rank must be known.
2103    Args:
2104      source_shape: the source DynamicRaggedShape
2105      target_shape: the target DynamicRaggedShape
2106      layer_broadcasters: List[_LayerBroadcaster] of length source_shape.rank.
2107      dtype: the preferred dtype of the broadcaster.
2108
2109    Raises:
2110      TypeError: if the input types don't match.
2111    """
2112    if not isinstance(source_shape, DynamicRaggedShape):
2113      raise TypeError("source_shape is not a DynamicRaggedShape")
2114    if not isinstance(target_shape, DynamicRaggedShape):
2115      raise TypeError("target_shape is not a DynamicRaggedShape")
2116    if not isinstance(layer_broadcasters, list):
2117      raise TypeError("layer_broadcasters not a list: " +
2118                      str(layer_broadcasters))
2119    for bc in layer_broadcasters:
2120      if not isinstance(bc, _LayerBroadcaster):
2121        raise TypeError("Not a LayerBroadcaster: " + str(bc))
2122
2123    dtype = _find_dtype(source_shape, dtype)
2124    dtype = _find_dtype(target_shape, dtype)
2125    dtype = _find_dtype_iterable(layer_broadcasters, dtype)
2126    dtype = _find_dtype(dtypes.int64, dtype)
2127    self._source_shape = source_shape.with_dtype(dtype)
2128    self._target_shape = target_shape.with_dtype(dtype)
2129    self._layer_broadcasters = [x.with_dtype(dtype) for x in layer_broadcasters]
2130
2131  def __repr__(self):
2132    return ("{src_shape:" + str(self._source_shape) + ", target_shape:" +
2133            str(self._target_shape) + " layer_broadcasters: " +
2134            str(self._layer_broadcasters) + "}")
2135
2136  def with_dtype(self, dtype):
2137    """Return a copy of this Broadcaster with a different dtype."""
2138    return _Broadcaster(self._source_shape, self._target_shape,
2139                        self._layer_broadcasters, dtype)
2140
2141  @property
2142  def source_shape(self):
2143    return self._source_shape
2144
2145  @property
2146  def target_shape(self):
2147    return self._target_shape
2148
2149  @property
2150  def dtype(self):
2151    return self._source_shape.dtype
2152
2153  def _target_inner_shape_int32(self):
2154    new_inner_shape = self.target_shape.inner_shape
2155    if new_inner_shape.dtype == dtypes.int64:
2156      new_inner_shape = math_ops.cast(new_inner_shape, dtype=dtypes.int32)
2157    return new_inner_shape
2158
2159  # pylint:disable=protected-access
2160  def broadcast_flat_values(self, rt, inner_dimensions=True):
2161    """flat_values of a ragged tensor broadcast to target_shape.
2162
2163    If inner_dimensions==True, then the result is a dense tensor with shape
2164    target_shape.inner_shape, the flat values of the broadcasted shape.
2165
2166    If you add target_shape.row_partitions, you will get the full broadcasted
2167    shape.
2168
2169    If inner_dimensions==False, the result is a dense tensor that satisfies
2170    certain properties:
2171    1. broadcast_to(result, target_shape.inner_shape) will give the result
2172       if inner_dimensions==True.
2173    2. Either (a) (result.rank < target_shape.inner_rank)
2174       or (b) (result.shape[0] == target_shape.inner_shape[0]).
2175    3. result.rank = min(target_shape.inner_rank, rt.rank)
2176    4. For i < target_shape.inner_rank - 1, and i < rt.rank,
2177       and if rt.shape[-i]!=1, then result.shape[-i]=target_shape[-i].
2178    Args:
2179      rt: a ragged or dense tensor.
2180      inner_dimensions: if true, broadcast the inner dimensions as well.
2181
2182    Returns:
2183      a dense tensor
2184    """
2185    if ragged_tensor.is_ragged(rt):
2186      rt = rt.flat_values
2187    # If rt was a regular tensor, it is its own flat_values.
2188    if self.target_shape.rank == 0:
2189      return rt
2190    inner_rank = self.target_shape.inner_rank
2191    if inner_rank > self._source_shape.rank:
2192      # The target's inner rank exceeds the source's rank, so we reshape
2193      # the source into a fully dense tensor.
2194      if self.source_shape.num_row_partitions > 0:
2195        rt = array_ops.reshape(
2196            rt, self.source_shape._alt_inner_shape(self.source_shape.rank))
2197      # rt.rank == self._source_shape.rank < inner_rank
2198      # Here, property 2a holds.
2199      if inner_dimensions:
2200        return array_ops.broadcast_to(rt, self._target_inner_shape_int32())
2201      return rt
2202    else:
2203      if self._source_shape.inner_rank != inner_rank:
2204        rt = array_ops.reshape(rt,
2205                               self._source_shape._alt_inner_shape(inner_rank))  # pylint:disable=protected-access
2206      # After the reshape, rt is flat_values with inner_rank.
2207      flat_broadcaster = self._layer_broadcasters[-inner_rank]
2208      rt = flat_broadcaster.broadcast_tensor(rt)
2209      # Here, property 2b holds.
2210      if inner_dimensions:
2211        rt = array_ops.broadcast_to(rt, self._target_inner_shape_int32())
2212      return rt
2213
2214  def broadcast(self, rt):
2215    """Broadcast a tensor of source_shape to target_shape."""
2216    flat_values = self.broadcast_flat_values(rt)
2217    return self.target_shape._add_row_partitions(flat_values)  # pylint:disable=protected-access
2218
2219
2220def _get_layer_broadcasters_from_rps(zero_broadcaster, source_rps, target_rps):
2221  """Get LayerBroadcasters from RowPartitions.
2222
2223           *--zero_broadcaster->*
2224           |                    |
2225         source_rps[0]     target_rps[0]
2226           |                    |
2227           V                    V
2228           *---result[1]------->*
2229           |                    |
2230         source_rps[1]     target_rps[1]
2231           |                    |
2232           V                    V
2233           *---result[2]------->*
2234                  .
2235                  .
2236                  .
2237           *---result[k-1]----->*
2238           |                    |
2239         source_rps[k-1]   target_rps[k-1]
2240           |                    |
2241           V                    V
2242           *---result[k]------->*
2243
2244  Note: result[0] = zero_broadcaster
2245
2246  Args:
2247    zero_broadcaster: a broadcaster between the source and target row
2248      partitions' rows, and equal to result[0].
2249    source_rps: source row partitions.
2250    target_rps: target row partitions (same length as source_rps).
2251
2252  Returns:
2253    result: a list of LayerBroadcasters.
2254  """
2255  if not isinstance(zero_broadcaster, _LayerBroadcaster):
2256    raise TypeError("Not a _LayerBroadcaster: " + str(zero_broadcaster))
2257  assert len(source_rps) == len(target_rps)
2258  if not source_rps:
2259    return [zero_broadcaster]
2260  next_broadcaster = zero_broadcaster.next_layer(source_rps[0], target_rps[0])
2261  tail_broadcasters = _get_layer_broadcasters_from_rps(next_broadcaster,
2262                                                       source_rps[1:],
2263                                                       target_rps[1:])
2264  return [zero_broadcaster] + tail_broadcasters
2265
2266
2267def _get_broadcaster(source_shape, target_shape):
2268  """Get a _Broadcaster from source_shape to target_shape."""
2269  if source_shape.dtype != target_shape.dtype:
2270    raise ValueError("The source and target row_split dtypes should be equal")
2271
2272  if (source_shape.rank is None or target_shape.rank is None):
2273    raise ValueError("Rank of source and target must be statically known")
2274  elif source_shape.rank > target_shape.rank:
2275    raise ValueError("Cannot broadcast to a shape with smaller rank")
2276  elif source_shape.rank == 0:
2277    return _Broadcaster(source_shape, target_shape, [])
2278  elif target_shape.rank == 1:
2279    assert source_shape.rank == 1
2280    layer = _LayerBroadcaster.first_layer(source_shape.inner_shape[0],
2281                                          target_shape.inner_shape[0])
2282    return _Broadcaster(source_shape, target_shape, [layer])
2283
2284  assert source_shape.rank <= target_shape.rank
2285  assert target_shape.rank >= 2
2286  assert source_shape.rank >= 1
2287
2288  source_rps = source_shape._as_row_partitions()  # pylint: disable=protected-access
2289
2290  target_rps = target_shape._as_row_partitions()  # pylint: disable=protected-access
2291
2292  assert len(target_rps) >= 1
2293  assert len(source_rps) <= len(target_rps)
2294  source_nrows = source_shape[0]
2295  if len(source_rps) < len(target_rps):
2296    # Note: this includes the case where len(source_rps)==0.
2297    # Here we begin at -1, one dimension before source_rps[0].
2298    # neg_one_source_rp  | neg_one_target_rp=target_rps[-(len(source_rps)+1)]
2299    # source_rps[0]      | target_rps[-len(source_rps)]
2300    # source_rps[1]      | target_rps[1-len(source_rps)]
2301    # ...                | ...
2302    # source_rps[-1]     | target_rps[-1]
2303    neg_one_source_rp = RowPartition.from_uniform_row_length(
2304        uniform_row_length=source_nrows, nrows=1, nvals=source_nrows)
2305    neg_one_target_rp = target_rps[-(len(source_rps) + 1)]
2306    neg_one_broadcaster = _LayerBroadcaster.get_singleton_broadcaster(
2307        neg_one_target_rp.nrows())
2308    zeroth_broadcaster = neg_one_broadcaster.next_layer(neg_one_source_rp,
2309                                                        neg_one_target_rp)
2310    target_rps_tail = target_rps[-len(source_rps):] if len(
2311        source_rps) >= 1 else []
2312
2313    layers = _get_layer_broadcasters_from_rps(zeroth_broadcaster, source_rps,
2314                                              target_rps_tail)
2315    return _Broadcaster(source_shape, target_shape, layers)
2316  else:
2317    assert len(target_rps) == len(source_rps)
2318    zeroth_broadcaster = _LayerBroadcaster.first_layer(source_rps[0].nrows(),
2319                                                       target_rps[0].nrows())
2320    layers = _get_layer_broadcasters_from_rps(zeroth_broadcaster, source_rps,
2321                                              target_rps)
2322
2323    return _Broadcaster(source_shape, target_shape, layers)
2324
2325
2326def _get_identity_broadcaster(shape):
2327  """Gets a Broadcaster for two identical shapes."""
2328  if shape.rank is None:
2329    raise ValueError("Shape must have a defined rank")
2330  layers = [
2331      _LayerBroadcaster.get_identity_broadcaster(
2332          shape._num_slices_in_dimension(i)) for i in range(shape.rank)  # pylint: disable=protected-access
2333  ]
2334  return _Broadcaster(shape, shape, layers)
2335
2336
2337def _broadcast_dynamic_shape_one_layer(a, b):
2338  """Broadcast two vectors, given their shapes.
2339
2340  Args:
2341    a: the shape of the first vector (a 1-D int tensor of length 1).
2342    b: the shape of the second vector (a 1-D int tensor of length 1).
2343
2344  Returns:
2345    (layer_a, layer_b, target_shape)
2346    layer_a is a _LayerBroadcaster from a to the target_shape.
2347    layer_b is a _LayerBroadcaster from b to the target_shape.
2348    target_shape is the target_shape
2349
2350  Raises:
2351    InvalidArgumentError if the shapes are not consistent.
2352  """
2353  a_0 = a[0]
2354  b_0 = b[0]
2355
2356  def broadcast_from_a():
2357    # Assumes a_0 == 1
2358    a_layer = array_ops.zeros(b_0, dtype=b_0.dtype)
2359    b_layer = math_ops.range(b_0)
2360    target = b
2361    return [a_layer, b_layer, target]
2362
2363  a_static = tensor_util.constant_value(a)
2364  if a_static is not None and a_static[0] == 1:
2365    [a_gi, b_gi, target] = broadcast_from_a()
2366    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
2367    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
2368    return [a_layer, b_layer, target]
2369
2370  def broadcast_from_b():
2371    # Assumes b_0 == 1
2372    a_layer = math_ops.range(a_0)
2373    b_layer = array_ops.zeros(a_0, dtype=a_0.dtype)
2374    target = a
2375    return [a_layer, b_layer, target]
2376
2377  b_static = tensor_util.constant_value(b)
2378  if b_static is not None and b_static[0] == 1:
2379    [a_gi, b_gi, target] = broadcast_from_b()
2380    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
2381    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
2382    return [a_layer, b_layer, target]
2383
2384  def broadcast_noop():
2385    # Assumes a_0 == b_0
2386    a_layer = math_ops.range(a_0)
2387    b_layer = math_ops.range(b_0)
2388    target = b
2389    return [a_layer, b_layer, target]
2390
2391  can_broadcast_from_a = math_ops.equal(a_0, 1)
2392  can_broadcast_from_b = math_ops.equal(b_0, 1)
2393
2394  def broadcast_not_from_a():
2395    return control_flow_ops.cond(
2396        can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
2397
2398  nrows_equal = math_ops.equal(a_0, b_0)
2399  can_broadcast = math_ops.logical_or(
2400      can_broadcast_from_a,
2401      math_ops.logical_or(can_broadcast_from_b, nrows_equal))
2402
2403  check_can_broadcast = check_ops.assert_equal(
2404      can_broadcast, True, message="Cannot broadcast")
2405
2406  results = control_flow_ops.cond(
2407      can_broadcast_from_a,
2408      true_fn=broadcast_from_a,
2409      false_fn=broadcast_not_from_a)
2410
2411  results = [
2412      control_flow_ops.with_dependencies([check_can_broadcast], x)
2413      for x in results
2414  ]
2415  [a_gi, b_gi, target] = results
2416  a_layer = _LayerBroadcaster.from_gather_index(a_gi)
2417  b_layer = _LayerBroadcaster.from_gather_index(b_gi)
2418  return [a_layer, b_layer, target]
2419
2420
2421def _broadcast_dynamic_shape_first_layer(a_0, b_0):
2422  """Broadcast the first layer of two dynamic shapes given the dimensions.
2423
2424  Args:
2425    a_0: the number of rows in a.
2426    b_0: the number of rows in b.
2427
2428  Returns:
2429    [layer_a, layer_b], where
2431    layer_a is a _LayerBroadcaster from a to the target.
2432    layer_b is a _LayerBroadcaster from b to the target.
2433  """
2434  def broadcast_from_a():
2435    # Assumes a_0 == 1
2436    a_layer = array_ops.zeros(b_0, dtype=b_0.dtype)
2437    b_layer = math_ops.range(b_0)
2438    return [a_layer, b_layer]
2439
2440  static_a_0 = tensor_util.constant_value(a_0)
2441  static_b_0 = tensor_util.constant_value(b_0)
2442  if static_a_0 is not None:
2443    if static_a_0 == static_b_0:
2444      id_broadcaster = _LayerBroadcaster.get_identity_broadcaster(
2445          static_a_0, dtype=a_0.dtype)
2446      return [id_broadcaster, id_broadcaster]
2447    elif static_a_0 == 1:
2448      return [
2449          _LayerBroadcaster.get_singleton_broadcaster(b_0),
2450          _LayerBroadcaster.get_identity_broadcaster(b_0)
2451      ]
2452
2453  if static_b_0 == 1:
2454    return [
2455        _LayerBroadcaster.get_identity_broadcaster(a_0),
2456        _LayerBroadcaster.get_singleton_broadcaster(a_0)
2457    ]
2458
2459  def broadcast_from_b():
2460    # Assumes b_0 == 1
2461    a_layer = math_ops.range(a_0)
2462    b_layer = array_ops.zeros(a_0, dtype=a_0.dtype)
2463    return [a_layer, b_layer]
2464
2465  def broadcast_noop():
2466    # Assumes a_0 == b_0
2467    a_layer = math_ops.range(a_0)
2468    b_layer = math_ops.range(b_0)
2469    return [a_layer, b_layer]
2470
2471  can_broadcast_from_a = math_ops.equal(a_0, constant_op.constant(1, a_0.dtype))
2472  can_broadcast_from_b = math_ops.equal(b_0, constant_op.constant(1, b_0.dtype))
2473
2474  def broadcast_not_from_a():
2475    return control_flow_ops.cond(
2476        can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)
2477
2478  # Ideally, this would only block control flow on broadcast_noop, but
2479  # the control flow doesn't seem to work.
2480  can_broadcast = math_ops.logical_or(
2481      math_ops.logical_or(can_broadcast_from_a, can_broadcast_from_b),
2482      math_ops.equal(a_0, b_0))
2483
2484  result = control_flow_ops.cond(
2485      can_broadcast_from_a,
2486      true_fn=broadcast_from_a,
2487      false_fn=broadcast_not_from_a)
2488
2489  return [
2490      _LayerBroadcaster.from_gather_index(
2491          control_flow_ops.with_dependencies(
2492              [check_ops.assert_equal(can_broadcast, True)], x)) for x in result
2493  ]
2494
2495
2496def _broadcast_half(
2497    ac_0: _LayerBroadcaster,
2498    a_1: RowPartition) -> Tuple[_LayerBroadcaster, RowPartition]:
2499  """Does a NOOP broadcast of a_1.
2500
2501      *-ac_0-->*
2502      |        |
2503     a_1      c_1
2504      |        |
2505      V        V
2506      *-ac_1-->*
2507
2508  Note that by definition this cannot fail: there is always a well-defined
2509  NOOP broadcast. This is usually intended as half of broadcasting two shapes
2510  together.
2511  Args:
2512    ac_0: previous LayerBroadcaster
2513    a_1: previous RowPartition
2514
2515  Returns:
2516    [ac_1, c_1] where ac_1 is the next LayerBroadcaster, and c_1 is the
2517    broadcast RowPartition
2518  """
2519  c_1 = ac_0.broadcast_row_partition(a_1)
2520  old_value_rowids = array_ops.gather(ac_0.gather_index, c_1.value_rowids())
2521  old_row_starts = array_ops.gather(a_1.row_splits(), old_value_rowids)
2522  gather_index = old_row_starts + c_1.offsets_in_rows()
2523  return [_LayerBroadcaster.from_gather_index(gather_index), c_1]
2524
2525
2526def _broadcast_dynamic_shape_next_layer_half_ragged(
2527    ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition,
2528    b_1: RowPartition
2529) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]:
2530  r"""Broadcast target and next layer broadcaster of two dynamic shapes.
2531
2532  a_1 is uniform, and b_1 is ragged.
2533     *--ac_0-->*<--bc_0--*
2534     |         |         |
2535    a_1       c_1       b_1
2536     |         |         |
2537     V         V         V
2538     *--ac_1-->*<--bc_1--*
2539
2540  Args:
2541    ac_0: _LayerBroadcaster from a to c in the previous layer.
2542    bc_0: _LayerBroadcaster from b to c in the previous layer.
2543    a_1: a uniform RowPartition for the next layer of a.
2544    b_1: a ragged RowPartition for the next layer of b.
2545
2546  Returns:
2547    (c_1, ac_1, bc_1)
2548    c_1: a RowPartition for the next layer of the dynamic shape.
2549    ac_1: _LayerBroadcaster from a to c in the next layer.
2550    bc_1: _LayerBroadcaster from b to c in the next layer.
2551  """
2552  if not isinstance(ac_0, _LayerBroadcaster):
2553    raise TypeError("ac_0 should be a _LayerBroadcaster")
2554  if not isinstance(bc_0, _LayerBroadcaster):
2555    raise TypeError("bc_0 should be a _LayerBroadcaster")
2556  if not isinstance(a_1, RowPartition):
2557    raise TypeError("a_1 should be a RowPartition")
2558  if not isinstance(b_1, RowPartition):
2559    raise TypeError("b_1 should be a RowPartition")
2560
2561  assert a_1.is_uniform()
2562  assert not b_1.is_uniform()
2563
2564  static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
2565  if static_a_1 == 1:
2566    [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2567    ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
2568    c_1 = RowPartition.from_row_splits(c_1b.row_splits())
2569    ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
2570    bc_1 = _LayerBroadcaster.from_gather_index(bc_1.gather_index)
2571    return [c_1, ac_1, bc_1]
2572
2573  def broadcast_noop():
2574    # The sides must be "equal".
2575    [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2576    [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2577    checks = [check_ops.assert_equal(c_1a.row_splits(), c_1b.row_splits())]
2578    return [
2579        control_flow_ops.with_dependencies(checks, x)
2580        for x in [a_1.row_splits(), ac_1.gather_index, bc_1.gather_index]
2581    ]
2582
2583  def broadcast_a():
2584    [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2585    ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
2586    return [
2587        c_1b.row_splits(),
2588        ac_1_gather_index,
2589        bc_1.gather_index,
2590    ]
2591
2592  can_broadcast_a = math_ops.equal(a_1.uniform_row_length(), 1)
2593
2594  [c_1_row_splits, ac_1_gather_index,
2595   bc_1_gather_index] = control_flow_ops.cond(
2596       can_broadcast_a, true_fn=broadcast_a, false_fn=broadcast_noop)
2597
2598  c_1 = RowPartition.from_row_splits(c_1_row_splits)
2599  ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
2600  bc_1 = _LayerBroadcaster.from_gather_index(bc_1_gather_index)
2601  return [c_1, ac_1, bc_1]
2602
2603
2604def _broadcast_dynamic_shape_next_layer_both_uniform(
2605    ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition,
2606    b_1: RowPartition
2607) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]:
2608  r"""Broadcast target and next layer broadcaster of two uniform dynamic shapes.
2609
2610     *--ac_0-->*<--bc_0--*
2611     |         |         |
2612    a_1       c_1       b_1
2613     |         |         |
2614     V         V         V
2615     *--ac_1-->*<--bc_1--*
2616
2617  Args:
2618    ac_0: _LayerBroadcaster from a to c in the previous layer.
2619    bc_0: _LayerBroadcaster from b to c in the previous layer.
2620    a_1: a RowPartition for the next layer of a.
2621    b_1: a RowPartition for the next layer of b.
2622
2623  Returns:
2624    (c_1, ac_1, bc_1)
2625    c_1: a RowPartition for the next layer of the dynamic shape.
2626    ac_1: _LayerBroadcaster from a to c in the next layer.
2627    bc_1: _LayerBroadcaster from b to c in the next layer.
2628  """
2629  if not isinstance(ac_0, _LayerBroadcaster):
2630    raise TypeError("ac_0 should be a _LayerBroadcaster")
2631  if not isinstance(bc_0, _LayerBroadcaster):
2632    raise TypeError("bc_0 should be a _LayerBroadcaster")
2633  if not isinstance(a_1, RowPartition):
2634    raise TypeError("a_1 should be a RowPartition")
2635  if not isinstance(b_1, RowPartition):
2636    raise TypeError("b_1 should be a RowPartition")
2637  assert a_1.is_uniform()
2638  assert b_1.is_uniform()
2639
2640  static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
2641  static_b_1 = tensor_util.constant_value(b_1.uniform_row_length())
2642
2643  if static_a_1 is not None:
2644    if static_a_1 == static_b_1:
2645      # Here, this dimension is the same, but we may have to broadcast previous
2646      # dimensions.
2647      [ac_1, _] = _broadcast_half(ac_0, a_1)
2648      [bc_1, _] = _broadcast_half(bc_0, b_1)
2649      c_1 = RowPartition.from_uniform_row_length(
2650          static_a_1,
2651          nrows=ac_0.dest_nrows())
2652      return [c_1, ac_1, bc_1]
2653    elif static_a_1 == 1:
2654      [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2655      ac_1 = _LayerBroadcaster.from_gather_index(
2656          array_ops.gather(ac_0.gather_index, c_1b.value_rowids()))
2657      c_1 = RowPartition.from_uniform_row_length(
2658          b_1.uniform_row_length(),
2659          nrows=bc_0.dest_nrows())
2660      return [c_1, ac_1, bc_1]
2661
2662  if static_b_1 == 1:
2663    [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2664    bc_1 = _LayerBroadcaster.from_gather_index(
2665        array_ops.gather(bc_0.gather_index, c_1a.value_rowids()))
2666    c_1 = RowPartition.from_uniform_row_length(
2667        a_1.uniform_row_length(),
2668        nrows=ac_0.dest_nrows())
2669    return [c_1, ac_1, bc_1]
2670
2671  def broadcast_noop():
2672    # Assumes a_1.uniform_row_length() == b_1.uniform_row_length()
2673    # Both sides broadcast to a single shape.
2674    [ac_1, _] = _broadcast_half(ac_0, a_1)
2675    [bc_1, _] = _broadcast_half(bc_0, b_1)
2676    return [a_1.uniform_row_length(), ac_1.gather_index, bc_1.gather_index]
2677
2678  def broadcast_a():
2679    [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2680    ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids())
2681    return [
2682        b_1.uniform_row_length(),
2683        ac_1_gather_index,
2684        bc_1.gather_index,
2685    ]
2686
2687  def broadcast_b():
2688    [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2689    bc_1_gather_index = array_ops.gather(bc_0.gather_index, c_1a.value_rowids())
2690    return [a_1.uniform_row_length(), ac_1.gather_index, bc_1_gather_index]
2691
2692  can_broadcast_b = math_ops.equal(b_1.uniform_row_length(), 1)
2693
2694  def no_broadcast_a():
2695    return control_flow_ops.cond(
2696        can_broadcast_b, true_fn=broadcast_b, false_fn=broadcast_noop)
2697
2698  can_broadcast_a = math_ops.equal(a_1.uniform_row_length(), 1)
2699
2700  broadcast_asserts = [
2701      check_ops.assert_equal(
2702          math_ops.logical_or(
2703              math_ops.logical_or(can_broadcast_a, can_broadcast_b),
2704              math_ops.equal(a_1.uniform_row_length(),
2705                             b_1.uniform_row_length())), True)
2706  ]
2707
2708  result = control_flow_ops.cond(
2709      can_broadcast_a, true_fn=broadcast_a, false_fn=no_broadcast_a)
2710
2711  [c_1_uniform_row_length, ac_1_gather_index, bc_1_gather_index] = [
2712      control_flow_ops.with_dependencies(broadcast_asserts, x) for x in result
2713  ]
2714
2715  c_1 = RowPartition.from_uniform_row_length(
2716      c_1_uniform_row_length,
2717      nvals=c_1_uniform_row_length * ac_0.dest_nrows(),
2718      nrows=ac_0.dest_nrows())
2719  ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index)
2720  bc_1 = _LayerBroadcaster.from_gather_index(bc_1_gather_index)
2721  return [c_1, ac_1, bc_1]
2722
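# Illustrative sketch (not part of the original source): broadcasting a
# uniform unit dimension against a uniform dimension of length 3, assuming
# ac_0 and bc_0 are identity broadcasters over the 2 rows above:
#
#   a_1 = RowPartition.from_uniform_row_length(1, nvals=2, nrows=2)
#   b_1 = RowPartition.from_uniform_row_length(3, nvals=6, nrows=2)
#   [c_1, ac_1, bc_1] = _broadcast_dynamic_shape_next_layer_both_uniform(
#       ac_0, bc_0, a_1, b_1)
#   # c_1 partitions 6 values into 2 rows of length 3, and ac_1.gather_index
#   # is [0, 0, 0, 1, 1, 1]: each unit row of `a` is repeated three times.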
2723
2724def _broadcast_dynamic_shape_next_layer(
2725    ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition,
2726    b_1: RowPartition
2727) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]:
2728  r"""Broadcast target and next layer broadcaster of two dynamic shapes.
2729
2730     *--ac_0-->*<--bc_0--*
2731     |         |         |
2732    a_1       c_1       b_1
2733     |         |         |
2734     V         V         V
2735     *--ac_1-->*<--bc_1--*
2736
2737  Args:
2738    ac_0: _LayerBroadcaster from a to c in the previous layer.
2739    bc_0: _LayerBroadcaster from b to c in the previous layer.
2740    a_1: a RowPartition for the next layer of a.
2741    b_1: a RowPartition for the next layer of b.
2742
2743  Returns:
2744    (c_1, ac_1, bc_1)
2745    c_1: a RowPartition for the next layer of the dynamic shape.
2746    ac_1: _LayerBroadcaster from a to c in the next layer.
2747    bc_1: _LayerBroadcaster from b to c in the next layer.
2748  """
2749  if not isinstance(ac_0, _LayerBroadcaster):
2750    raise TypeError("ac_0 should be a _LayerBroadcaster")
2751  if not isinstance(bc_0, _LayerBroadcaster):
2752    raise TypeError("bc_0 should be a _LayerBroadcaster")
2753  if not isinstance(a_1, RowPartition):
2754    raise TypeError("a_1 should be a RowPartition")
2755  if not isinstance(b_1, RowPartition):
2756    raise TypeError("b_1 should be a RowPartition")
2757
2758  if a_1.is_uniform():
2759    if b_1.is_uniform():
2760      return _broadcast_dynamic_shape_next_layer_both_uniform(
2761          ac_0, bc_0, a_1, b_1)
2762    else:
2763      return _broadcast_dynamic_shape_next_layer_half_ragged(
2764          ac_0, bc_0, a_1, b_1)
2765  else:
2766    if b_1.is_uniform():
2767      [c_1, bc_1, ac_1] = _broadcast_dynamic_shape_next_layer_half_ragged(  # pylint: disable=arguments-out-of-order
2768          bc_0, ac_0, b_1, a_1)
2769      return (c_1, ac_1, bc_1)
2770    else:
2771      # Neither partition is uniform: the row_splits must match exactly.
2772      [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
2773      [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
2774      check_valid = [
2775          check_ops.assert_equal(c_1a.row_splits(), c_1b.row_splits())
2776      ]
2777      return (c_1a._with_dependencies(check_valid),  # pylint: disable=protected-access
2778              ac_1.with_dependencies(check_valid),
2779              bc_1.with_dependencies(check_valid))
2780
2781
2782def _broadcast_dynamic_shape_from_rps(
2783    a_zero: _LayerBroadcaster, b_zero: _LayerBroadcaster,
2784    a_rps: Sequence[RowPartition], b_rps: Sequence[RowPartition]
2785) -> Tuple[Sequence[RowPartition], Sequence[_LayerBroadcaster],
2786           Sequence[_LayerBroadcaster]]:
2787  """Create _LayerBroadcasters from two shapes to a target shape.
2788
2790      *--a_zero->*<-b_zero-*
2791      |          |         |
2792   a_rps[0]    c_rps[0]  b_rps[0]
2793      |          |         |
2794      V          V         V
2795      *--ac[1]-->*<-bc[1]--*
2796      |          |         |
2797   a_rps[1]   c_rps[1]   b_rps[1]
2798      |          |         |
2799      V          V         V
2800      *--ac[2]-->*<-bc[2]--*
2801
2802  Note: ac[0]=a_zero, and bc[0]=b_zero.

2803  Args:
2804    a_zero: broadcaster from rows of a_rps[0] to target shape.
2805    b_zero: broadcaster from rows of b_rps[0] to target shape.
2806    a_rps: RowPartitions of first shape.
2807    b_rps: RowPartitions of second shape, equal in length to a_rps.
2808
2809  Returns:
2810    (c_rps, ac, bc) where:
2811    c_rps: RowPartitions of target shape.
2812    ac: layers broadcasting from the first shape.
2813    bc: layers broadcasting from the second shape.
2814  """
2815  assert len(a_rps) == len(b_rps)
2816  if a_rps:
2817    (c_1, ac_1,
2818     bc_1) = _broadcast_dynamic_shape_next_layer(a_zero, b_zero, a_rps[0],
2819                                                 b_rps[0])
2820    (c_suffix, a_layers,
2821     b_layers) = _broadcast_dynamic_shape_from_rps(ac_1, bc_1, a_rps[1:],
2822                                                   b_rps[1:])
2823
2824    return ([c_1] + c_suffix, [ac_1] + a_layers, [bc_1] + b_layers)
2825  else:
2826    return ([], [], [])
2827
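# Illustrative sketch (not part of the original source): for single-layer
# inputs the recursion performs one _broadcast_dynamic_shape_next_layer step
# and then bottoms out on the empty tails:
#
#   (c_rps, ac, bc) = _broadcast_dynamic_shape_from_rps(
#       a_zero, b_zero, [a_rp], [b_rp])
#   # c_rps == [c_1], ac == [ac_1], bc == [bc_1], where (c_1, ac_1, bc_1) is
#   # _broadcast_dynamic_shape_next_layer(a_zero, b_zero, a_rp, b_rp).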
2828
2829def _get_broadcast_num_row_partitions(a: DynamicRaggedShape,
2830                                      b: DynamicRaggedShape):
2831  """Returns broadcast_dynamic_shape(a, b).num_row_partitions."""
2832  # Assumes rank and num_row_partitions are not None.
2833  if (a.num_row_partitions == 0 and b.num_row_partitions == 0):
2834    return 0
2835  expanded_num_row_partitions_a = a.num_row_partitions + max(0, b.rank - a.rank)
2836  expanded_num_row_partitions_b = b.num_row_partitions + max(0, a.rank - b.rank)
2837
2838  if a.num_row_partitions == 0:
2839    return expanded_num_row_partitions_b
2840
2841  if b.num_row_partitions == 0:
2842    return expanded_num_row_partitions_a
2843
2844  return max(expanded_num_row_partitions_a, expanded_num_row_partitions_b)
2845
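# Worked example (illustrative): if a has rank 2 with 1 row partition and b
# has rank 4 with 2 row partitions, then after right-aligning the shapes:
#   expanded_num_row_partitions_a = 1 + max(0, 4 - 2) = 3
#   expanded_num_row_partitions_b = 2 + max(0, 2 - 4) = 2
# so the broadcast shape has max(3, 2) = 3 row partitions.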
2846
2847# pylint: disable=protected-access
2848def _broadcast_dynamic_shape_extended_complete(
2849    a: DynamicRaggedShape, b: DynamicRaggedShape, b_rps: Sequence[RowPartition],
2850    c_suffix: Sequence[RowPartition], ac: Sequence[_LayerBroadcaster],
2851    bc_suffix: Sequence[_LayerBroadcaster]
2852) -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]:
2853  """Helper for broadcast_dynamic_shape_extended."""
2854  c_prefix = b_rps[:-len(c_suffix)]
2855  bc_prefix_length = b.rank - len(bc_suffix)
2856  bc_prefix = [
2857      _LayerBroadcaster.get_identity_broadcaster(b._num_slices_in_dimension(i))
2858      for i in range(bc_prefix_length)
2859  ]
2860  c_num_row_partitions = _get_broadcast_num_row_partitions(a, b)
2861
2862  c_raw = DynamicRaggedShape.from_row_partitions(c_prefix + tuple(c_suffix))
2863  c = c_raw._with_num_row_partitions(c_num_row_partitions)
2864  return (c, _Broadcaster(a, c, ac), _Broadcaster(b, c, bc_prefix + bc_suffix))
2865
2866
2867def _broadcast_dynamic_shape_extended_helper(
2868    a: DynamicRaggedShape, b: DynamicRaggedShape
2869) -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]:
2870  """Helper for broadcast_dynamic_shape_extended.
2871
2872  Here, we force:
2873    a.rank <= b.rank
2874    2 <= b.rank
2875    1 <= a.rank

2876  Args:
2877    a: a DynamicRaggedShape
2878    b: a DynamicRaggedShape
2879
2880  Returns:
2881    A triple of a shape and two broadcasters.
2882  """
2883  assert a.rank <= b.rank
2884  assert 2 <= b.rank
2885  assert 1 <= a.rank
2886  a_rps = a._as_row_partitions()  # pylint: disable=protected-access
2887  b_rps = b._as_row_partitions()  # pylint: disable=protected-access
2888
2889  if len(a_rps) < len(b_rps):
2890    # Note: this includes the case where len(a_rps)==0.
2891    # Here we begin at -1, one dimension before a_rps[0].
2892    # neg_one_a_rp  | b_rps[-(len(a_rps)+1)]
2893    # a_rps[0]      | b_rps[-len(a_rps)]
2894    # a_rps[1]      | b_rps[1-len(a_rps)]
2895    # ...           | ...
2896    # a_rps[-1]     | b_rps[-1]
2897
2898    a_nrows = a[0]
2899    a_nrows_static = tensor_util.constant_value(a_nrows)
2900    if a_nrows_static is not None:
2901      a_nrows = a_nrows_static
2902
2903    neg_one_a_rp = RowPartition.from_uniform_row_length(
2904        uniform_row_length=a_nrows, nrows=1, nvals=a_nrows)
2905    neg_one_b_rp = b_rps[-(len(a_rps) + 1)]
2906    (neg_one_ac, neg_one_bc) = _broadcast_dynamic_shape_first_layer(
2907        constant_op.constant(1, dtype=b_rps[0].dtype), neg_one_b_rp.nrows())
2908
2909    # The first part of the solution.
2910    (c_zero, ac_zero,
2911     bc_zero) = _broadcast_dynamic_shape_next_layer(neg_one_ac, neg_one_bc,
2912                                                    neg_one_a_rp, neg_one_b_rp)
2913    b_rps_tail = b_rps[-len(a_rps):] if len(a_rps) >= 1 else []
2914
2915    (c_suffix, ac_layers,
2916     bc_layers) = _broadcast_dynamic_shape_from_rps(ac_zero, bc_zero, a_rps,
2917                                                    b_rps_tail)
2918
2919    return _broadcast_dynamic_shape_extended_complete(
2920        a=a,
2921        b=b,
2922        b_rps=b_rps,
2923        c_suffix=[c_zero] + c_suffix,
2924        ac=[ac_zero] + ac_layers,
2925        bc_suffix=[neg_one_bc, bc_zero] + bc_layers)
2926
2927  else:
2928    assert len(a_rps) == len(b_rps)
2929    (ac_zero,
2930     bc_zero) = _broadcast_dynamic_shape_first_layer(a_rps[0].nrows(),
2931                                                     b_rps[0].nrows())
2932
2933    (c_rps, a_layers,
2934     b_layers) = _broadcast_dynamic_shape_from_rps(ac_zero, bc_zero, a_rps,
2935                                                   b_rps)
2936    return _broadcast_dynamic_shape_extended_complete(
2937        a=a,
2938        b=b,
2939        b_rps=b_rps,
2940        c_suffix=c_rps,
2941        ac=[ac_zero] + a_layers,
2942        bc_suffix=[bc_zero] + b_layers)
2943
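# Illustrative alignment (not part of the original source): for a.rank == 2
# and b.rank == 4, a_rps has 1 partition and b_rps has 3, so in the branch
# above neg_one_b_rp = b_rps[-2] and b_rps_tail = [b_rps[-1]], while b_rps[0]
# is covered by the identity prefix broadcaster built in
# _broadcast_dynamic_shape_extended_complete.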
2944
2945def _fix_start_index(index, rank, num_row_partitions):
2946  """Fixes a slice's start index; slice indices are silently truncated."""
2947  if index < 0:
2948    if rank is None:
2949      raise ValueError(
2950          "Rank must be known to use __getitem__ on a negative index.")
2951    index = rank + index
2952  if index < 0:
2953    index = 0
2954  if (num_row_partitions > 0 and index <= num_row_partitions + 1):
2955    # The rank is always >= num_row_partitions + 1 if num_row_partitions > 0.
2956    return index
2957  if index == 0:
2958    return index
2959  if rank is None:
2960    raise ValueError("Rank must be known to use __getitem__ on a large index.")
2961  if index >= rank:
2962    index = rank
2963  return index
2964
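# Illustrative values (not part of the original source), with rank=3 and
# num_row_partitions=0:
#   _fix_start_index(-1, 3, 0) == 2  # negative index resolved against rank
#   _fix_start_index(-7, 3, 0) == 0  # underflow silently clamped to 0
#   _fix_start_index(5, 3, 0) == 3   # overflow silently clamped to rank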
2965
2966def _fix_stop_index(index, rank):
2967  """Fixes a slice's stop index; slice indices are silently truncated."""
2968  if index is None:
2969    if rank is None:
2970      raise ValueError("Rank must be known to use __getitem__ without a stop.")
2971    index = rank
2972  if index < 0:
2973    if rank is None:
2974      raise ValueError(
2975          "Rank must be known to use __getitem__ on a negative index.")
2976    index = rank + index
2977  if index < 0:
2978    index = 0
2979  if rank is not None:
2980    index = min(rank, index)
2981  return index
2982
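# Illustrative values (not part of the original source), with rank=4:
#   _fix_stop_index(None, 4) == 4  # a missing stop means "through the end"
#   _fix_stop_index(-1, 4) == 3    # negative index resolved against rank
#   _fix_stop_index(9, 4) == 4     # overflow silently clamped to rank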
2983
2984def _first_layer_gather_index(nrows_source, nrows_target):
2985  """Return the first layer gather_index.
2986
2987  Args:
2988    nrows_source: the number of rows in the source.
2989    nrows_target: the number of rows in the target.
2990
2991  Returns:
2992    A tensor, usable as a gather_index for a _LayerBroadcaster.
2993  """
2994
2995  def gi_broadcast_first():
2996    return array_ops.zeros(nrows_target, dtype=nrows_target.dtype)
2997
2998  def gi_no_broadcast_first():
2999    gather_index = math_ops.range(nrows_target, dtype=nrows_target.dtype)
3000    return gather_index
3001
3002  do_broadcast = math_ops.equal(nrows_source,
3003                                constant_op.constant(1, nrows_source.dtype))
3004  nrows_equal = math_ops.equal(nrows_source, nrows_target)
3005  can_broadcast = check_ops.assert_equal(
3006      math_ops.logical_or(do_broadcast, nrows_equal),
3007      True,
3008      message="Cannot broadcast")
3009
3010  gather_index = control_flow_ops.cond(
3011      do_broadcast, true_fn=gi_broadcast_first, false_fn=gi_no_broadcast_first)
3012
3013  return control_flow_ops.with_dependencies([can_broadcast], gather_index)
3014
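# Illustrative sketch (not part of the original source): with int64 scalars,
#   _first_layer_gather_index(constant_op.constant(1, dtypes.int64),
#                             constant_op.constant(3, dtypes.int64))
# evaluates to [0, 0, 0] (the single source row is broadcast), while equal
# row counts such as 3 and 3 yield the identity index [0, 1, 2]. Any other
# combination fails the "Cannot broadcast" assertion at runtime.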
3015
3016def _next_layer_gather_index(bc, original_rp, broadcast_rp):
3017  r"""Create the next layer gather_index whether or not a broadcast happens.
3018
3019     *----------bc-------->*
3020     |                     |
3021  original_rp           broadcast_rp
3022     |                     |
3023    \|/                   \|/
3024     *--next_broadcaster-->*
3025
3026  Args:
3027    bc: the old broadcaster.
3028    original_rp: the original row partition.
3029    broadcast_rp: the target row partition.
3030
3031  Returns:
3032    the gather_index for next_broadcaster.
3033  Raises:
3034    InvalidArgumentError if the shapes are incompatible.
3035  """
3036  old_value_rowids = array_ops.gather(bc.gather_index,
3037                                      broadcast_rp.value_rowids())
3038
3039  def gi_no_broadcast():
3040    # TODO(martinz): decide if row_splits or row_starts should be used here.
3041    old_row_starts = array_ops.gather(original_rp.row_splits(),
3042                                      old_value_rowids)
3043    expected_row_lengths = array_ops.gather(
3044        params=original_rp.row_lengths(), indices=bc.gather_index)
3045    actual_row_lengths = broadcast_rp.row_lengths()
3046    check_valid = check_ops.assert_equal(
3047        expected_row_lengths, actual_row_lengths, message="Cannot broadcast")
3048    gather_index = old_row_starts + broadcast_rp.offsets_in_rows()
3049    return control_flow_ops.with_dependencies([check_valid], gather_index)
3050
3051  def gi_broadcast():
3052    # Several optimizations can occur here.
3053    # old_row_starts == old_value_rowids, because:
3054    #   if you are broadcasting, then the source has uniform row length of 1,
3055    #   implying original_rp.row_splits == tf.range(original_rp.nvals + 1)
3056    # When broadcasting, there is no need to add offsets to the
3057    # source, because the source has size 1.
3058    # Also, this is always valid, because we enforce source and destination
3059    # have uniform_row_length.
3060    return old_value_rowids
3061
3062  if not original_rp.is_uniform():
3063    return gi_no_broadcast()
3064
3065  do_broadcast = math_ops.equal(original_rp.uniform_row_length(),
3066                                constant_op.constant(1, original_rp.dtype))
3067  gather_index = control_flow_ops.cond(
3068      do_broadcast, true_fn=gi_broadcast, false_fn=gi_no_broadcast)
3069
3070  return gather_index
3071
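# Illustrative sketch of the broadcast path (not part of the original source):
# assuming bc.gather_index == [0, 1], original_rp uniform with row length 1
# over 2 rows, and broadcast_rp with row lengths [2, 3]:
#   broadcast_rp.value_rowids() == [0, 0, 1, 1, 1]
#   gather_index == gather([0, 1], [0, 0, 1, 1, 1]) == [0, 0, 1, 1, 1]
# i.e. each unit source row is repeated to fill the corresponding target row.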
3072
3073def _flat_values_shape(rt):
3074  if isinstance(rt, ragged_tensor.RaggedTensor):
3075    return array_ops.shape(rt.flat_values)
3076  return rt.flat_values.shape
3077
3078
3079def _to_row_partitions_and_nvals_from_lengths(
3080    lengths: Sequence[Union[int, Sequence[int]]],
3081    dtype=None) -> Tuple[Sequence[RowPartition], int]:
3082  """Converts a lengths spec (ragged and uniform dims) to RowPartitions.
3083
3084  For example, [2, [2,1], 2] represents a shape like:
3085  [[[0, 0], [0, 0]], [[0, 0]]]
3086
3087  Args:
3088    lengths: a list of integers and lists of integers.
3089    dtype: dtype of the shape (tf.int32 or tf.int64)
3090
3091  Returns:
3092    a sequence of RowPartitions, and the number of values of the last partition.
3093  """
3094  size_so_far = lengths[0]
3095  result = []
3096  for current_lengths in lengths[1:]:
3097    if isinstance(current_lengths, int):
3098      nrows = size_so_far
3099      nvals = current_lengths * nrows
3100      size_so_far = nvals
3101      result.append(
3102          RowPartition.from_uniform_row_length(
3103              current_lengths, nvals, nrows=nrows, dtype_hint=dtype))
3104    else:
3105      if size_so_far != len(current_lengths):
3106        raise ValueError("Shape not consistent.")
3107      result.append(
3108          RowPartition.from_row_lengths(current_lengths, dtype_hint=dtype))
3109      size_so_far = sum(current_lengths)
3110  return (result, size_so_far)
3111
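# Worked example (illustrative), matching the docstring shape [2, [2, 1], 2]:
#   (rps, nvals) = _to_row_partitions_and_nvals_from_lengths([2, [2, 1], 2])
#   # rps[0] partitions 3 values into rows of lengths [2, 1]
#   # rps[1] is uniform with row length 2 over those 3 rows (6 values)
#   # nvals == 6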
3112
3113def _element_to_string(x):
3114  """Converts an element to its string form within a printed list."""
3115  if x is Ellipsis:
3116    return "..."
3117  if isinstance(x, str):
3118    return "'" + x + "'"
3119  return str(x)
3120
3121
3122def _list_tail_with_ellipsis(arr):
3123  """Returns the string form of the tail of a list that may contain Ellipsis."""
3124  if not arr:
3125    return "]"
3126  else:
3127    return ", " + _element_to_string(arr[0]) + _list_tail_with_ellipsis(arr[1:])
3128
3129
3130def _list_with_ellipsis_to_str(arr):
3131  """Returns the string form of a list that may contain Ellipsis."""
3132  if not arr:
3133    return "[]"
3134  return "[" + _element_to_string(arr[0]) + _list_tail_with_ellipsis(arr[1:])
3135
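# Illustrative values (not part of the original source):
#   _list_with_ellipsis_to_str([]) == "[]"
#   _list_with_ellipsis_to_str([1, 'a', Ellipsis]) == "[1, 'a', ...]"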
3136
3137def _is_int_or_tuple_of_ints(x):
3138  if isinstance(x, int):
3139    return True
3140  if not isinstance(x, tuple):
3141    return False
3142  for y in x:
3143    if not isinstance(y, int):
3144      return False
3145  return True
3146
3147
3148def _alt_inner_shape_from_tensor_shape(shape, dtype, new_inner_rank):
3149  """Helper for _alt_inner_shape, used directly in _with_num_row_partitions."""
3150  if new_inner_rank == 1:
3151    return constant_op.constant([shape.num_elements()], dtype=dtype)
3152  new_inner_rank_tail_length = new_inner_rank - 1
3153  inner_shape_tail = shape[-new_inner_rank_tail_length:].as_list()
3154  first_dim = shape[:-new_inner_rank_tail_length].num_elements()
3155  return constant_op.constant([first_dim] + inner_shape_tail, dtype=dtype)
3156
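# Illustrative values (not part of the original source), for shape [2, 3, 4]:
#   _alt_inner_shape_from_tensor_shape(
#       tensor_shape.TensorShape([2, 3, 4]), dtypes.int64, 2)  # -> [6, 4]
#   _alt_inner_shape_from_tensor_shape(
#       tensor_shape.TensorShape([2, 3, 4]), dtypes.int64, 1)  # -> [24]
# The dropped leading dimensions are folded into the first retained one.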
3157
3158def _safe_floor_div(dividend: tensor_shape.Dimension,
3159                    divisor: tensor_shape.Dimension) -> Optional[tensor_shape.Dimension]:
3160  if tensor_shape.dimension_value(divisor) == 0:
3161    return None
3162  return dividend // divisor
3163
3164
3165# TODO(b/218932570)
3166def _reduce_prod_patch(x):
3167  if x.dtype == dtypes.int64:
3168    return math_ops.cast(
3169        math_ops.reduce_prod(math_ops.cast(x, dtypes.int32)), dtypes.int64)
3170  return math_ops.reduce_prod(x)
3171
3172
3173# Type alias for shape encoded as a DynamicRaggedShape or a Tensor.
3174DenseOrRaggedShape = Union[DynamicRaggedShape, core.TensorLike]
3175
3176
3177def _merge_row_partitions(
3178    row_partitions: Sequence[RowPartition]) -> RowPartition:
3179  # TODO(martinz): handle uniform splits.
3180  # TODO(martinz): consider using value_row_ids if present.
3181  # Note: this probably won't be called with len(row_partitions)==1, so no
3182  # need to optimize.
3183  row_splits = row_partitions[0].row_splits()
3184  for rp in row_partitions[1:]:
3185    row_splits = array_ops.gather(rp.row_splits(), row_splits)
3186  return RowPartition.from_row_splits(row_splits)
3187
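# Illustrative sketch (not part of the original source): merging an outer
# partition with row_splits [0, 2, 3] over an inner partition with row_splits
# [0, 1, 3, 6] gathers to [0, 3, 6]: the 6 innermost values are partitioned
# directly into the 2 outermost rows.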
3188
3189def _merge_inner_shape(
3190    inner_shape: ops.Tensor,
3191    static_inner_shape: tensor_shape.TensorShape,
3192    outer_axis: int,
3193    inner_axis: int) -> Tuple[ops.Tensor, tensor_shape.TensorShape]:
3194  """Merge the inner shape of a DynamicRaggedShape."""
3195  prefix = inner_shape[:outer_axis]
3196  suffix = inner_shape[inner_axis + 1:]
3197
3198  internal = inner_shape[outer_axis:inner_axis + 1]
3199  internal_value = [_reduce_prod_patch(internal)]
3200  new_internal = array_ops.concat([prefix, internal_value, suffix], axis=0)
3201  prefix_static = static_inner_shape[:outer_axis]
3202  suffix_static = static_inner_shape[inner_axis+1:]
3203  internal_static = static_inner_shape[outer_axis:inner_axis+1]
3204  internal_value_static = tensor_shape.TensorShape(
3205      [internal_static.num_elements()])
3206  new_internal_static = prefix_static + internal_value_static + suffix_static
3207
3208  return (new_internal, new_internal_static)
3209
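# Illustrative values (not part of the original source): merging axes 0..1 of
# inner shape [2, 3, 4] yields [6, 4]:
#   _merge_inner_shape(constant_op.constant([2, 3, 4], dtypes.int64),
#                      tensor_shape.TensorShape([2, 3, 4]),
#                      outer_axis=0, inner_axis=1)
#   # -> (dynamic shape [6, 4], TensorShape([6, 4]))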
3210
3211def _batch_rp_spec(rp_spec: RowPartitionSpec,
3212                   batch_size: Optional[int]) -> RowPartitionSpec:
3213  """Batches a RowPartitionSpec.
3214
3215  Given a RowPartitionSpec and a batch_size, create a RowPartitionSpec that
3216  will be the spec for the concatenation of batch_size RowPartitions.
3217
3218  A RowPartition can be considered a transformation from a list of a given
3219  length to a list of lists. Assume rp_a maps list_a to nlist_a and rp_b
3220  maps list_b to nlist_b. Then concat(rp_a, rp_b) maps
3221  concat(list_a, list_b) to concat(nlist_a, nlist_b).
3222
3223  If batch_size is None, the returned spec can handle the concatenation of
3224  an arbitrary number of RowPartitions.
3225
3226  Args:
3227    rp_spec: a RowPartitionSpec for all the RowPartitions to be concatenated.
3228    batch_size: the number of RowPartitions to be concatenated.

3229  Returns:
3230    a batched RowPartitionSpec.
3231  """
3232  if batch_size is None:
3233    return RowPartitionSpec(uniform_row_length=rp_spec.uniform_row_length,
3234                            dtype=rp_spec.dtype)
3235  nrows = None if rp_spec.nrows is None else rp_spec.nrows * batch_size
3236  nvals = None if rp_spec.nvals is None else rp_spec.nvals * batch_size
3237  return RowPartitionSpec(
3238      nrows=nrows, nvals=nvals, uniform_row_length=rp_spec.uniform_row_length,
3239      dtype=rp_spec.dtype)
3240
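# Illustrative values (not part of the original source):
#   _batch_rp_spec(RowPartitionSpec(nrows=2, nvals=6, uniform_row_length=3), 4)
#   # -> RowPartitionSpec(nrows=8, nvals=24, uniform_row_length=3)
# With batch_size=None, only uniform_row_length and dtype are kept, since the
# number of concatenated partitions is unknown.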
3241
3242def _batch_rp_spec_head(old_head: RowPartitionSpec,
3243                        batch_size: Optional[int]) -> RowPartitionSpec:
3244  """Creates a RowPartitionSpec for the new leading batch dimension."""
3245  nvals = None if (old_head.nrows is None or
3246                   batch_size is None) else batch_size * old_head.nrows
3247  return RowPartitionSpec(
3248      nrows=batch_size, nvals=nvals, uniform_row_length=old_head.nrows,
3249      dtype=old_head.dtype)
3250
3251
3252def _batch_static_inner_shape(
3253    old_shape: tensor_shape.TensorShape,
3254    batch_size: Optional[int]) -> tensor_shape.TensorShape:
3255  """Returns a copy of old_shape with axis=0 multiplied by batch_size.
3256
3257  Only use if this is the inner_shape of a DynamicRaggedShape.Spec with one
3258  or more row partitions.
3259
3260  Args:
3261    old_shape: the original inner_shape.
3262    batch_size: the batch size.
3263
3264  Returns:
3265    a new shape.
3266  """
3267  head_dim = tensor_shape.dimension_at_index(old_shape, 0) * batch_size
3268  return head_dim + old_shape[1:]
3269
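# Illustrative value (not part of the original source):
#   _batch_static_inner_shape(tensor_shape.TensorShape([3, 4]), 2)
#   # -> TensorShape([6, 4])
# _unbatch_static_inner_shape below is the inverse: [6, 4] // 2 -> [3, 4].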
3270
3271def _batch_tensor_shape(old_shape: tensor_shape.TensorShape,
3272                        batch_size: int) -> tensor_shape.TensorShape:
3273  return tensor_shape.TensorShape([batch_size]) + old_shape
3274
3275
3276def _unbatch_static_inner_shape(
3277    old_shape: tensor_shape.TensorShape,
3278    batch_size: Optional[int]) -> tensor_shape.TensorShape:
3279  """Unbatch a static_inner_shape when num_row_partitions > 0."""
3280  head_dim = tensor_shape.dimension_at_index(old_shape, 0) // batch_size
3281  return head_dim + old_shape[1:]
3282
3283
3284# Copied from ragged_array_ops.py
3285def ones(shape: DynamicRaggedShape,
3286         dtype=dtypes.float32,
3287         name: Optional[str] = None) -> ragged_tensor.RaggedOrDense:
3288  """Returns a RaggedTensor of ones with the given DynamicRaggedShape."""
3289  flat_values = array_ops.ones(shape.inner_shape, dtype=dtype, name=name)
3290  return ragged_tensor.RaggedTensor._from_nested_row_partitions(  # pylint: disable=protected-access
3291      flat_values, shape.row_partitions)
3292
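# Usage sketch (illustrative; assumes DynamicRaggedShape.from_lengths as
# defined earlier in this module):
#   shape = DynamicRaggedShape.from_lengths([2, [2, 1], 3])
#   ones(shape)
#   # -> <tf.RaggedTensor [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
#   #                      [[1.0, 1.0, 1.0]]]>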