# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shapes & broadcasting for RaggedTensors.

TODO(martinz): make this suitable for output for tf.shape
TODO(martinz): replace ragged_tensor_shape with this.
"""


import abc
from typing import Any, Iterable, Optional, Sequence, Tuple, Union

import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.ops.ragged.row_partition import RowPartitionSpec
from tensorflow.python.types import core
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


class _DynamicRaggedShapeBatchEncoder(extension_type.ExtensionTypeBatchEncoder):
  """A batch encoder for DynamicRaggedShape below."""

  def batch(self, spec: "DynamicRaggedShape.Spec",
            batch_size) -> "DynamicRaggedShape.Spec":
    if spec.num_row_partitions:
      new_head = _batch_rp_spec_head(spec._row_partitions[0], batch_size)  # pylint:disable=protected-access
      new_tail = [_batch_rp_spec(rp, batch_size) for rp in spec._row_partitions]  # pylint:disable=protected-access
      new_rp = [new_head] + new_tail
      new_static_inner_shape = _batch_static_inner_shape(
          spec._static_inner_shape, batch_size)  # pylint:disable=protected-access

      return DynamicRaggedShape.Spec(
          row_partitions=new_rp,
          static_inner_shape=new_static_inner_shape,
          dtype=spec.dtype)
    elif batch_size is None:
      if spec.inner_rank == 0:
        return DynamicRaggedShape.Spec._from_tensor_shape([None],  # pylint:disable=protected-access
                                                          0,
                                                          dtype=spec.dtype)
      else:
        # Might be None
        new_head = RowPartitionSpec(uniform_row_length=spec._dimension(0),  # pylint:disable=protected-access
                                    dtype=spec.dtype)
        new_static_inner_shape = _batch_static_inner_shape(
            spec._static_inner_shape, batch_size)  # pylint:disable=protected-access
        return DynamicRaggedShape.Spec(
            row_partitions=[new_head],
            static_inner_shape=new_static_inner_shape,
            dtype=spec.dtype)
    else:

      return DynamicRaggedShape.Spec(
          row_partitions=[],
          static_inner_shape=_batch_tensor_shape(spec._static_inner_shape,  # pylint:disable=protected-access
                                                 batch_size),
          dtype=spec.dtype)

  def unbatch(self,
              spec: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape.Spec":
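    # Descriptive note (added): unbatching drops the outermost row partition.
    # When that partition is uniform and its nrows (the batch size) is
    # statically known, that value (`scale`) divides the known nrows/nvals of
    # the remaining partitions and of the inner shape; once a non-uniform
    # partition is reached, only uniform row lengths are preserved for the
    # later partitions.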
"DynamicRaggedShape.Spec": 86 if spec.num_row_partitions: 87 result = [] 88 head = spec._row_partitions[0] # pylint:disable=protected-access 89 scale = None if head.uniform_row_length is None else head.nrows 90 91 for rp in spec._row_partitions[1:]: # pylint:disable=protected-access 92 if scale is None: 93 result.append( 94 RowPartitionSpec( 95 nrows=None, 96 nvals=None, 97 uniform_row_length=rp.uniform_row_length, 98 dtype=spec.dtype)) 99 else: 100 nrows = None if rp.nrows is None else rp.nrows//scale 101 if rp.uniform_row_length is None: 102 scale = None 103 result.append(RowPartitionSpec(nrows=nrows, 104 nvals=None, 105 uniform_row_length=None, 106 dtype=spec.dtype)) 107 else: 108 result.append( 109 RowPartitionSpec( 110 nrows=nrows, 111 nvals=rp.nvals // scale, 112 uniform_row_length=rp.uniform_row_length, 113 dtype=spec.dtype)) 114 return DynamicRaggedShape.Spec( 115 row_partitions=result, 116 static_inner_shape=_unbatch_static_inner_shape( 117 spec._static_inner_shape, scale), # pylint:disable=protected-access 118 dtype=spec.dtype) 119 else: # spec.num_row_partitions == 0 120 return DynamicRaggedShape.Spec( 121 row_partitions=[], 122 static_inner_shape=spec._static_inner_shape[1:], # pylint:disable=protected-access 123 dtype=spec.dtype) 124 125 def decode(self, spec: "DynamicRaggedShape.Spec", encoding 126 ) -> "DynamicRaggedShape": 127 return DynamicRaggedShape.from_tensor(encoding, dtype=spec.dtype) 128 129 def encode(self, spec: "DynamicRaggedShape.Spec", value, minimum_rank=0 130 ) -> Union[ragged_tensor.RaggedTensor, ops.Tensor]: 131 return ones(value, dtype=dtypes.bool) 132 133 def encoding_specs( 134 self, 135 spec: "DynamicRaggedShape.Spec" 136 ) -> Union[ragged_tensor.RaggedTensorSpec, tensor_spec.TensorSpec]: 137 if spec.rank != 0: 138 ragged_rank = spec.num_row_partitions 139 else: 140 # special case: need to unbatch twice to get ragged tensor. 141 ragged_rank = -1 142 return ragged_tensor.RaggedTensorSpec( 143 shape=spec._to_tensor_shape(), # pylint:disable=protected-access 144 dtype=dtypes.bool, 145 ragged_rank=ragged_rank, 146 row_splits_dtype=spec.dtype) 147 148 149# TODO(martinz): allow inner_shape to be a fully defined TensorShape. 150# A "fully defined TensorShape" means one where the rank and all dimensions are 151# known. 152# Allowing inner_shape might mean allowing inner_shape to be initialized by 153# a fully defined TensorShape, or it might mean that you can actually store 154# TensorShape in the inner_shape field. This could conceivably construct 155# a DynamicRaggedShape that was dtype agnostic. 156# 157# TODO(martinz): unify the impl of the determination of index type across 158# RowPartition and DynamicRaggedShape. 159@tf_export("experimental.DynamicRaggedShape") 160class DynamicRaggedShape(extension_type.BatchableExtensionType): 161 """The shape of a ragged or dense tensor. 162 163 Ragged shapes are encoded using two fields: 164 165 * `inner_shape`: An integer vector giving the shape of a dense tensor. 166 * `row_partitions`: A list of `RowPartition` objects, describing how 167 that flat shape should be partitioned to add ragged axes. 168 169 If a DynamicRaggedShape is the shape of a RaggedTensor rt, then: 170 1. row_partitions = rt._nested_row_partitions 171 (and thus len(row_partitions) > 0) 172 2. inner_shape is the shape of rt.flat_values 173 174 If a DynamicRaggedShape is the shape of a dense tensor t, then: 175 1. row_partitions = [] 176 2. inner_shape is the shape of t. 

  Examples:

  The following table gives a few examples (where `RP(lengths)` is short
  for `RowPartition.from_lengths(lengths)`):

  Row Partitions              | Inner Shape  | Example Tensor
  --------------------------- | ------------ | ----------------------------
  []                          | [2, 3]       | `[[1, 2, 3], [4, 5, 6]]`
  [RP([2, 0, 3])]             | [5]          | `[[1, 2], [], [3, 4, 5]]`
  [RP([2, 1])]                | [3, 2]       | `[[[1, 2], [3, 4]], [[5, 6]]]`
  [RP([2, 1]), RP([2, 1, 2])] | [5]          | `[[[1, 2], [3]], [[4, 5]]]`
  """
  _row_partitions: Tuple[RowPartition, ...]
  _inner_shape: ops.Tensor
  _static_inner_shape: tensor_shape.TensorShape
  __batch_encoder__ = _DynamicRaggedShapeBatchEncoder()
  __name__ = "tf.DynamicRaggedShape"

  def __init__(self,
               row_partitions: Sequence[RowPartition],
               inner_shape: core.TensorLike,
               dtype: Optional[dtypes.DType] = None,
               validate: bool = False,
               static_inner_shape: ... = None):
    """Core constructor for a DynamicRaggedShape.

    Create a DynamicRaggedShape. This can be used to construct a
    DynamicRaggedShape representing a ragged or dense shape. If row_partitions
    is an empty list, then this is equivalent to a dense shape.

    If row_partitions is specified, then the num_row_partitions will be equal
    to len(row_partitions). There are several checks made.
    Specifically:
    1. Consecutive row_partitions must have consistent nvals and nrows.
    2. The last row_partition must have nvals equal to the first element of
       inner_shape.

    The inner_shape is converted to a tensor.
    All row_partitions and the inner_shape are converted to the same dtype
    (int64 or int32).

    Args:
      row_partitions: the row_partitions of the shape.
      inner_shape: if len(row_partitions) > 0, the shape of the flat_values.
        Otherwise, the shape of the tensor.
      dtype: tf.int64, tf.int32, or None representing the preferred dtype.
      validate: if true, dynamic validation is applied to the shape.
      static_inner_shape: if len(row_partitions) > 0, the static shape of the
        flat_values. Otherwise, the static shape of the tensor.
        Should be convertible to a TensorShape.
    """
    if not isinstance(row_partitions, Iterable):
      raise TypeError(
          "row_partitions should be a list of row partitions. Instead, got " +
          str(row_partitions))
    for x in row_partitions:
      if not isinstance(x, RowPartition):
        raise TypeError("row_partitions contains " + str(x) +
                        " which is not a RowPartition")
    dtype = _find_dtype_iterable(row_partitions, dtype)
    dtype = _find_dtype(inner_shape, dtype)
    if (isinstance(inner_shape, np.ndarray) and
        inner_shape.dtype == np.int32 and dtype is None):
      dtype = dtypes.int32
    dtype = _find_dtype(dtypes.int64, dtype)

    row_partitions = tuple([rp.with_dtype(dtype) for rp in row_partitions])
    self._row_partitions = row_partitions
    self._inner_shape = ops.convert_to_tensor(
        inner_shape, dtype_hint=dtype, name="inner_dim_sizes")
    if self._inner_shape.dtype != dtype:
      self._inner_shape = math_ops.cast(self._inner_shape, dtype)

    checks = []
    # Validate shapes.
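    # Descriptive note (added): static mismatches raise a ValueError
    # immediately; when `validate=True`, dynamic assertions are also collected
    # in `checks` and attached below as control dependencies on the inner
    # shape and the row partitions.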
    if self._row_partitions:
      for axis, rp in enumerate(self._row_partitions):
        if axis > 0:
          previous_row_partition = self._row_partitions[axis - 1]
          msg = ("RowPartitions in DynamicRaggedShape do not align "
                 f"between {axis - 1} and {axis}")
          static_nrows = rp.static_nrows
          static_nvals = previous_row_partition.static_nvals
          if (static_nrows is not None) and (static_nvals is not None):
            if static_nrows != static_nvals:
              raise ValueError(msg)
            else:
              continue
          if validate:
            checks.append(
                check_ops.assert_equal(
                    previous_row_partition.nvals(),
                    rp.nrows(),
                    message=msg))

    self._inner_shape.shape.assert_has_rank(1)

    self._static_inner_shape = tensor_util.constant_value_as_shape(
        self._inner_shape)
    if static_inner_shape is not None:
      self._static_inner_shape = self._static_inner_shape.merge_with(
          static_inner_shape)

    if row_partitions:
      last_row_partition = row_partitions[-1]
      static_nvals = last_row_partition.static_nvals
      static_inner_shape_nvals = tensor_shape.dimension_value(
          self._static_inner_shape[0])
      if static_nvals is not None and static_inner_shape_nvals is not None:
        if static_nvals != static_inner_shape_nvals:
          raise ValueError("Last row partition does not match inner_shape.")
      elif validate:
        checks.append(
            check_ops.assert_equal(
                last_row_partition.nvals(),
                self._inner_shape[0],
                message="Last row partition does not match inner_shape."))
    if checks:
      self._inner_shape = control_flow_ops.with_dependencies(
          checks, self._inner_shape, name="inner_shape_validated")
      self._row_partitions = [
          rp._with_dependencies(checks) for rp in self._row_partitions  # pylint: disable=protected-access
      ]

  @classmethod
  def from_lengths(cls,
                   lengths: Sequence[Union[Sequence[int], int]],
                   num_row_partitions=None,
                   dtype=dtypes.int64):
    """Creates a shape with the given lengths and num_row_partitions.

    The lengths can either be a nonnegative int or a list of nonnegative ints.

    If num_row_partitions is None, then the minimal num_row_partitions is used.

    For example, [2, (3, 2)] is the shape of [[0, 0, 0], [0, 0]], and
    [2, 2] is the shape of [[0, 0], [0, 0]].

    This chooses the minimal num_row_partitions required (including zero).

    The following table gives a few examples (where `RP(lengths)` is short
    for `RowPartition.from_lengths(lengths)`):

    from_lengths           | row_partitions              | inner_shape
    ---------------------- | --------------------------- | ------------
    []                     | []                          | []
    [2, (3, 2)]            | [RP([3, 2])]                | [5]
    [2, 2]                 | []                          | [2, 2]
    [2, (3, 2), 7]         | [RP([3, 2])]                | [5, 7]
    [2, (2, 2), 3]         | [RP([2, 2])]                | [4, 3]
    [2, 2, 3]              | []                          | [2, 2, 3]
    [2, (2, 1), (2, 0, 3)] | [RP([2, 1]), RP([2, 0, 3])] | [5]

    If we want the row partitions to end with uniform row partitions, then
    we can set num_row_partitions.

    For example, below URP(3, 12) is RowPartition.from_uniform_row_length(3,
    12):

    from_lengths   | num_row_partitions | row_partitions           | inner_shape
    -------------- | ------------------ | ------------------------ | -----------
    [2, (3, 2), 2] | 2                  | [RP([3, 2]), URP(2, 10)] | [10]
    [2, 2]         | 1                  | [URP(2, 4)]              | [4]
    [2, 2, 3]      | 0                  | []                       | [2, 2, 3]
    [2, 2, 3]      | 1                  | [URP(2, 4)]              | [4, 3]
    [2, 2, 3]      | 2                  | [URP(2, 4), URP(3, 12)]  | [12]

    Representing the shapes from init():

    from_lengths             | Tensor Example
    ------------------------ | ------------------------------
    `[2, 3]`                 | `[[1, 2, 3], [4, 5, 6]]`
    `[3, (2, 0, 3)]`         | `[[1, 2], [], [3, 4, 5]]`
    `[2, (2, 1), 2]`         | `[[[1, 2], [3, 4]], [[5, 6]]]`
    `[2, (2, 1), (2, 1, 2)]` | `[[[1, 2], [3]], [[4, 5]]]`

    Args:
      lengths: the lengths of sublists along each axis.
      num_row_partitions: the num_row_partitions of the result or None
        indicating the minimum number of row_partitions.
      dtype: the dtype of the shape (tf.int32 or tf.int64).

    Returns:
      a new DynamicRaggedShape
    """
    if not isinstance(lengths, list):
      raise ValueError("lengths should be a list")
    for x in lengths:
      if not _is_int_or_tuple_of_ints(x):
        raise ValueError(
            "element of lengths should be int or tuple of ints: instead %r" %
            (x,))

    if num_row_partitions is None:
      # Calculate the minimal num_row_partitions.
      is_list = [not isinstance(x, int) for x in lengths]
      if any(is_list):
        # The index of the last ragged (list) entry.
        num_row_partitions = len(is_list) - is_list[-1::-1].index(True) - 1
      else:
        num_row_partitions = 0

    if not isinstance(num_row_partitions, int):
      raise ValueError("num_row_partitions should be an int or None")

    if not lengths:
      if num_row_partitions > 0:
        raise ValueError("num_row_partitions must be 0 for a scalar shape")
      return DynamicRaggedShape([], [], dtype=dtype)

    if not num_row_partitions < len(lengths):
      raise ValueError(
          "num_row_partitions should be less than `len(lengths)` "
          "if shape is not scalar."
      )

    if num_row_partitions > 0:
      (row_partitions, nvals) = _to_row_partitions_and_nvals_from_lengths(
          lengths[:num_row_partitions + 1])
      inner_shape = [nvals] + lengths[num_row_partitions + 1:]
      return DynamicRaggedShape(
          row_partitions, inner_shape, dtype=dtype)
    else:
      return DynamicRaggedShape([], lengths, dtype=dtype)

  @classmethod
  def from_row_partitions(cls, row_partitions, dtype=None):
    """Create a shape from row_partitions.

    Args:
      row_partitions: a nonempty list of RowPartition objects.
      dtype: the dtype to use, or None to use the row_partitions dtype.

    Returns:
      a DynamicRaggedShape with inner_rank==1.
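
    For example (illustrative):

    >>> rp = tf.experimental.RowPartition.from_row_lengths([2, 0, 3])
    >>> tf.experimental.DynamicRaggedShape.from_row_partitions([rp])
    <DynamicRaggedShape lengths=[3, (2, 0, 3)] num_row_partitions=1>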
416 """ 417 if not row_partitions: 418 raise ValueError("row_partitions cannot be empty") 419 inner_shape = [row_partitions[-1].nvals()] 420 return DynamicRaggedShape( 421 row_partitions, inner_shape, dtype=dtype) 422 423 @classmethod 424 def _from_inner_shape(cls, inner_shape, dtype=None): 425 """Create a shape from inner_shape, where num_row_partitions == 0.""" 426 return DynamicRaggedShape([], inner_shape, dtype=dtype) 427 428 # pylint: disable=protected-access 429 @classmethod 430 def from_tensor(cls, t, dtype=None): 431 """Constructs a ragged shape for a potentially ragged tensor.""" 432 if ragged_tensor.is_ragged(t): 433 return DynamicRaggedShape( 434 t._nested_row_partitions, _flat_values_shape(t), dtype=dtype) 435 else: 436 return DynamicRaggedShape._from_inner_shape( 437 array_ops.shape(t), dtype=dtype) 438 439 @property 440 def row_partitions(self): 441 """The row_partitions of the shape.""" 442 return self._row_partitions 443 444 @property 445 def num_row_partitions(self): 446 """The number of row_partitions of the shape.""" 447 return len(self._row_partitions) 448 449 @property 450 def dtype(self): 451 """The dtype of the shape -- one of tf.int32 or tf.int64.""" 452 return self._inner_shape.dtype 453 454 def _static_inner_shape_as_list(self, truncate_first): 455 """Returns the lengths of the inner shape (if rank known), or [...].""" 456 if self._static_inner_shape.rank is None: 457 return [...] 458 result = self._static_inner_shape.as_list() 459 if truncate_first: 460 return result[1:] 461 return result 462 463 def static_lengths(self, ragged_lengths=True): 464 """Returns a list of statically known axis lengths. 465 466 This represents what values are known. For each row partition, it presents 467 either the uniform row length (if statically known), 468 the list of row lengths, or none if it is not statically known. 469 For the inner shape, if the rank is known, then each dimension is reported 470 if known, and None otherwise. If the rank of the inner shape is not known, 471 then the returned list ends with an ellipsis. 472 473 Args: 474 ragged_lengths: If false, returns None for all ragged dimensions. 475 476 Returns: 477 A Sequence[Union[Sequence[int],int, None]] of lengths, with a possible 478 Ellipsis at the end. 
479 """ 480 if self.num_row_partitions == 0: 481 return self._static_inner_shape_as_list(False) 482 first_dim = self.row_partitions[0].static_nrows 483 if isinstance(first_dim, tensor_shape.Dimension): 484 first_dim = first_dim.value 485 rp_dims = [first_dim] 486 for rp in self.row_partitions: 487 if rp.is_uniform(): 488 rp_dims.append(rp.static_uniform_row_length) 489 elif ragged_lengths: 490 const_vals = tensor_util.constant_value(rp.row_lengths()) 491 if const_vals is None: 492 rp_dims.append(None) 493 else: 494 rp_dims.append(tuple(const_vals.tolist())) 495 else: 496 rp_dims.append(None) 497 498 return rp_dims + self._static_inner_shape_as_list(True) 499 500 def __repr__(self): 501 lengths = _list_with_ellipsis_to_str(self.static_lengths()) 502 return ("<DynamicRaggedShape " 503 "lengths=%s num_row_partitions=%r>" % 504 (lengths, self.num_row_partitions)) 505 506 def _to_tensor_shape(self) -> tensor_shape.TensorShape: 507 """Returns a TensorShape representation of the shape.""" 508 lengths = self.static_lengths(ragged_lengths=False) 509 if not lengths: 510 return tensor_shape.TensorShape(()) 511 if lengths[-1] == Ellipsis: 512 return tensor_shape.TensorShape(None) 513 return tensor_shape.TensorShape(lengths) 514 515 def _slice_shape(self, start, stop): 516 """Returns a shape self[start:stop]. 517 518 If start == 0, then this truncates dimensions after stop. 519 If start != 0, then this will return a shape with num_row_partitions == 0. 520 521 See __getitem__. 522 523 Args: 524 start: the first dimension. 0 <= start <= rank 525 stop: the last dimension (exclusive). 0 <= stop <= rank 526 """ 527 if stop <= start: 528 return DynamicRaggedShape._from_inner_shape([]) 529 elif start == 0: 530 if stop <= self.num_row_partitions: 531 if stop == 1: 532 return DynamicRaggedShape._from_inner_shape( 533 [self.row_partitions[0].nrows()]) 534 new_row_partitions = self.row_partitions[:stop - 1] 535 new_inner_shape = [new_row_partitions[-1].nvals()] 536 return DynamicRaggedShape(new_row_partitions, new_inner_shape) 537 else: 538 if self.rank is None: 539 new_inner_rank = stop - self.num_row_partitions 540 new_inner_shape = self.inner_shape[:new_inner_rank] 541 return DynamicRaggedShape( 542 row_partitions=self.row_partitions, 543 inner_shape=new_inner_shape, 544 static_inner_shape=None, 545 validate=False) 546 547 elif self.rank <= stop: 548 return self 549 new_inner_rank = stop - self.num_row_partitions 550 new_inner_shape = self.inner_shape[:new_inner_rank] 551 return DynamicRaggedShape( 552 row_partitions=self.row_partitions, 553 inner_shape=new_inner_shape, 554 static_inner_shape=tensor_shape.TensorShape([None] 555 * new_inner_rank), 556 validate=False) 557 else: 558 if self.rank is None or stop < self.rank: 559 partial = self._slice_shape(0, stop) 560 else: 561 partial = self 562 563 for x in partial.row_partitions: 564 if not x.is_uniform(): 565 raise ValueError("All relevant dimensions must be uniform") 566 if partial.rank is None: 567 # TODO(martinz): Implement _with_num_row_partitions(0) if rank is 568 # unknown, and remove. 
        raise NotImplementedError(
            "__getitem__[start:stop] where start > 0 not implemented")

      return DynamicRaggedShape._from_inner_shape(
          partial._with_num_row_partitions(0).inner_shape[start:])

  def _dimension(self, index):
    """Return a dimension, if the dimension is not ragged (see __getitem__)."""
    rank = self.rank
    if not isinstance(index, int):
      raise TypeError("index should be an int")
    if (self.num_row_partitions == 0 or index > self.num_row_partitions + 1):
      # If num_row_partitions > 0 and index <= num_row_partitions + 1, then
      # we are safe.
      if rank is None:
        raise ValueError(
            "Rank must be known to use __getitem__ on a large index.")
      if index >= rank:
        raise IndexError("Index is too big: " + str(index) + ">=" + str(rank))
    if index < 0:
      raise IndexError("Index must be non-negative: " + str(index))
    elif not self.is_uniform(index):
      raise ValueError("Index " + str(index) + " is not uniform")
    elif index == 0 and self.num_row_partitions > 0:
      static_nrows = self.row_partitions[0].static_nrows
      if static_nrows is not None:
        return constant_op.constant(static_nrows, dtype=self.dtype)
      return self.row_partitions[0].nrows()
    elif self.num_row_partitions == 0:
      static_result = tensor_shape.dimension_value(
          self._static_inner_shape[index])
      if static_result is not None:
        return constant_op.constant(static_result, dtype=self.dtype)
      return self.inner_shape[index]
    elif index > self.num_row_partitions:
      static_result = tensor_shape.dimension_value(
          self._static_inner_shape[index - self.num_row_partitions])
      if static_result is not None:
        return constant_op.constant(static_result, dtype=self.dtype)

      return self.inner_shape[index - self.num_row_partitions]
    else:
      return self.row_partitions[index - 1].uniform_row_length()

  def __getitem__(self, index):
    """Returns a dimension or a slice of the shape.

    Ragged shapes can have ragged dimensions that depend upon other dimensions.
    Therefore, if you ask for a dimension that is ragged, this function raises
    a ValueError. For similar reasons, if a slice is selected that includes
    a ragged dimension without including the zero dimension, then this fails.

    Any slice that does not start at zero will return a shape
    with num_row_partitions == 0.

    Args:
      index: the index; can be an int or a slice.

    Raises:
      IndexError: if the index is not in range.
      ValueError: if the rank is unknown, or a ragged rank is requested
        incorrectly.
    """
    rank = self.rank
    if isinstance(index, slice):

      if (index.step is not None) and (index.step != 1):
        raise IndexError("Cannot stride through a shape")
      start = index.start
      stop = index.stop
      if start is None:
        start = 0
      start = _fix_start_index(start, rank, self.num_row_partitions)
      stop = _fix_stop_index(stop, rank)
      return self._slice_shape(start, stop)
    elif isinstance(index, int):
      if index < 0:
        if rank is None:
          raise ValueError(
              "Rank must be known to use __getitem__ with a negative index.")
        return self._dimension(rank + index)
      return self._dimension(index)
    else:
      raise TypeError("Argument is not an int or a slice")

  def _num_elements(self):
    """Number of elements in a shape.

    Returns:
      The number of elements in the shape.
659 660 """ 661 return math_ops.reduce_prod(self.inner_shape) 662 663 def _num_slices_in_dimension(self, axis): 664 """The total size of a dimension (like nvals). 665 666 Effectively, this is self[:axis+1]._num_elements() 667 668 Example: 669 shape = DynamicRaggedShape._from_inner_shape([2, 3, 4]) 670 shape._num_slices_in_dimension(0) = 2 671 shape._num_slices_in_dimension(1) = 6 672 shape._num_slices_in_dimension(2) = 24 673 shape._num_slices_in_dimension(-1) = 24 674 shape._num_slices_in_dimension(-2) = 6 675 shape._num_slices_in_dimension(-2) = 2 676 677 Args: 678 axis: the last axis to include in the number of elements. If negative, 679 then axis = axis + rank. 680 681 Returns: 682 The number of elements in the shape. 683 """ 684 if not isinstance(axis, int): 685 raise TypeError("axis must be an integer") 686 if axis < 0: 687 rank = self.rank 688 if rank is None: 689 raise ValueError( 690 "You can't use negative values if the rank is undefined") 691 axis = axis + rank 692 if axis == 0: 693 return self._dimension(0) 694 if axis <= self.num_row_partitions: 695 return self.row_partitions[axis - 1].nvals() 696 # If self.num_row_partitions = 1, and 697 # self.inner_shape=[3,5,6], and axis=2, then you want: 698 # 15 = 3 * 5 = math_ops.reduce_prod(self.inner_shape[:2]) 699 # 2 = axis - (self.num_row_partitions - 1) 700 # If num_row_partitions=0, and 701 # self.inner_shape=[3,5,6] and axis=2, then you want: 702 # 90 = 3 * 5 * 6 = math_ops.reduce_prod(self.inner_shape[:3]) 703 # 3 = axis - (self.num_row_partitions - 1) 704 remainder = axis - (self.num_row_partitions - 1) 705 return _reduce_prod_patch(self.inner_shape[:remainder]) 706 707 def is_uniform(self, axis): 708 """Returns true if the indicated dimension is uniform.""" 709 if not isinstance(axis, int): 710 raise TypeError("axis must be an integer") 711 rank = self.rank 712 if axis < 0: 713 raise IndexError("Negative axis values are not supported") 714 elif rank is not None and axis >= rank: 715 raise IndexError("Expected axis=%s < rank=%s" % (axis, rank)) 716 else: 717 return ((axis == 0 or axis > len(self._row_partitions)) # pylint:disable=superfluous-parens 718 or self._row_partitions[axis - 1].is_uniform()) 719 720 @property 721 def rank(self): 722 """The number of dimensions in this shape, or None if unknown.""" 723 inner_rank = self.inner_rank 724 if inner_rank is None: 725 return None 726 else: 727 return self.num_row_partitions + inner_rank 728 729 @property 730 def inner_shape(self): 731 """The inner dimension sizes for this shape. 732 733 Returns: 734 A 1-D integer `Tensor`. 735 """ 736 return self._inner_shape 737 738 @property 739 def inner_rank(self): 740 """The rank of inner_shape.""" 741 return tensor_shape.dimension_value(self._static_inner_shape.rank) 742 743 def _alt_inner_shape(self, new_inner_rank): 744 """Get an alternative inner shape with higher or lower rank. 745 746 For the rank of the inner shape to be be higher, the last few ragged 747 dimensions must have uniform_row_length. 748 749 Args: 750 new_inner_rank: the new rank of the inner_shape 751 752 Returns: 753 A new inner_shape of rank new_inner_rank. 
754 """ 755 if new_inner_rank == 0: 756 raise ValueError("new_inner_rank cannot be zero") 757 elif self.inner_rank == 0: 758 raise ValueError("old inner_rank cannot be zero") 759 elif new_inner_rank == self.inner_rank: 760 return self.inner_shape 761 elif new_inner_rank < self.inner_rank: 762 if self._static_inner_shape.is_fully_defined(): 763 return _alt_inner_shape_from_tensor_shape(self._static_inner_shape, 764 self.dtype, new_inner_rank) 765 first_dimension = self._num_slices_in_dimension(-new_inner_rank) 766 if new_inner_rank == 1: 767 return array_ops.expand_dims(first_dimension, 0) 768 remaining_dimensions = self.inner_shape[1 - new_inner_rank:] 769 return array_ops.concat( 770 [array_ops.expand_dims(first_dimension, 0), remaining_dimensions], 771 axis=0) 772 else: 773 assert new_inner_rank > self.inner_rank 774 new_dimensions = new_inner_rank - self.inner_rank 775 if any( 776 [not x.is_uniform() for x in self.row_partitions[-new_dimensions:]]): 777 raise ValueError("Cannot get an inner shape over a ragged dimension") 778 first_dimension = self._num_slices_in_dimension(-new_inner_rank) 779 new_dimensions = new_inner_rank - self.inner_rank 780 new_dims = [first_dimension] + [ 781 x.uniform_row_length() for x in self.row_partitions[-new_dimensions:] 782 ] 783 return array_ops.concat([array_ops.stack(new_dims), self.inner_shape[1:]], 784 axis=0) 785 786 def _inner_shape_dim(self, dimension): 787 """Returns an int or a tensor representing _inner_shape[dimension].""" 788 result = tensor_shape.dimension_value(self._static_inner_shape[dimension]) 789 return self._inner_shape[dimension] if result is None else result 790 791 def _with_inner_rank(self, inner_rank): 792 """Returns the same shape but a different inner_rank. 793 794 All dimensions that are to be represented in the inner_shape must be dense. 795 See inner_rank. 796 797 Args: 798 inner_rank: the new inner_rank of the shape. 799 800 Returns: 801 the same shape but a different inner_rank 802 803 Raises: 804 ValueError if the new dense rank is invalid, or the old rank is unknown. 805 """ 806 rank = self.rank 807 if rank is None: 808 raise ValueError("Rank must be known to adjust inner_rank") 809 elif rank < 2: 810 if inner_rank == rank: 811 return self 812 raise ValueError("Cannot change inner_rank if rank < 2") 813 else: 814 # When self.rank is not None: 815 # self.rank = self.inner_rank + self.num_row_partitions 816 new_num_row_partitions = rank - inner_rank 817 return self._with_num_row_partitions(new_num_row_partitions) 818 819 def _with_num_row_partitions(self, num_row_partitions): 820 """Creates an identical shape with the given num_row_partitions. 821 822 Note that the shape must be statically refactorable to this rank. 823 In particular: 824 * rank must be known. 825 * num_row_partitions must be a nonnegative int. 826 * num_row_partitions must be less than the rank of the shape 827 * num_row_partitions must be greater or equal to the index of any ragged 828 dimension. 829 830 Note that if the num_row_partitions is the same, self is returned. 831 832 Args: 833 num_row_partitions: the target num_row_partitions (must be a nonnegative 834 int). 835 836 Returns: 837 a shape with a (possibly) different num_row_partitions. 838 839 Raises: 840 ValueError: if the rank is unknown, the argument is not a nonnegative int, 841 or there is a dimension that is nonuniform. 
842 """ 843 rank = self.rank 844 if rank is None: 845 raise ValueError("Rank must be known to adjust num_row_partitions") 846 if not isinstance(num_row_partitions, int): 847 raise ValueError("num_row_partitions must be an int") 848 if num_row_partitions < 0: 849 raise ValueError("num_row_partitions must be nonnegative") 850 if num_row_partitions == self.num_row_partitions: 851 return self 852 if num_row_partitions >= rank: 853 raise ValueError("num_row_partitions must be less than rank") 854 if num_row_partitions > self.num_row_partitions: 855 num_row_partitions_diff = num_row_partitions - self.num_row_partitions 856 new_inner_rank = self.rank - num_row_partitions 857 nvals = self._inner_shape_dim(0) 858 more_rp = [] 859 for i in range(num_row_partitions_diff): 860 nrows = nvals 861 row_length = self._inner_shape_dim(i + 1) 862 nvals = nrows * row_length 863 rp = RowPartition.from_uniform_row_length( 864 row_length, nrows=nrows, dtype=self.dtype) 865 more_rp.append(rp) 866 alt_inner = self._alt_inner_shape(new_inner_rank) 867 return DynamicRaggedShape( 868 list(self.row_partitions) + more_rp, alt_inner) 869 else: 870 assert num_row_partitions < self.num_row_partitions 871 return DynamicRaggedShape( 872 self.row_partitions[:num_row_partitions], 873 self._alt_inner_shape(self.rank - num_row_partitions)) 874 875 def _merge_dims(self, outer_axis: int, 876 inner_axis: int) -> "DynamicRaggedShape": 877 """Merges outer_axis...inner_axis into a single dimension. 878 879 Returns a copy of this shape with the specified range of dimensions 880 flattened into a single dimension, with elements in row-major order. 881 882 #### Examples: 883 884 >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1), (1,2,3)])._merge_dims(0, 1) # pylint: disable=line-too-long 885 <DynamicRaggedShape lengths=[3, (1, 2, 3)] num_row_partitions=1> 886 >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1), (1,2,3)])._merge_dims(1, 2) # pylint: disable=line-too-long 887 <DynamicRaggedShape lengths=[2, (3, 3)] num_row_partitions=1> 888 >>> tf.experimental.DynamicRaggedShape.from_lengths([2, (2,1), (1,2,3)])._merge_dims(0, 2) # pylint: disable=line-too-long 889 <DynamicRaggedShape lengths=[6] num_row_partitions=0> 890 891 To mimic the behavior of `np.flatten` (which flattens all dimensions), use 892 `rt.merge_dims(0, -1). To mimic the behavior of `tf.layers.Flatten` (which 893 flattens all dimensions except the outermost batch dimension), use 894 `rt.merge_dims(1, -1)`. 895 896 Args: 897 outer_axis: `int`: The first dimension in the range of dimensions to 898 merge. May be negative if `self.shape.rank` is statically known. 899 inner_axis: `int`: The last dimension in the range of dimensions to merge. 900 May be negative if `self.shape.rank` is statically known. 901 902 Returns: 903 A copy of this shape, with the specified dimensions merged into a 904 single dimension. The returned shape will be 905 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` 906 is the total number of slices in the merged dimensions. 
907 """ 908 outer_axis = array_ops.get_positive_axis( 909 outer_axis, 910 self.rank, 911 axis_name="outer_axis", 912 ndims_name="rank(self)") 913 inner_axis = array_ops.get_positive_axis( 914 inner_axis, 915 self.rank, 916 axis_name="inner_axis", 917 ndims_name="rank(self)") 918 if not outer_axis <= inner_axis: 919 raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or " 920 f"equal to inner_axis ({inner_axis}).") 921 if outer_axis == inner_axis: 922 return self 923 if self.num_row_partitions == 0: 924 # A dense tensor. 925 (new_inner_shape, new_static_inner_shape) = _merge_inner_shape( 926 self._inner_shape, self._static_inner_shape, outer_axis, inner_axis) 927 return DynamicRaggedShape([], 928 new_inner_shape, 929 dtype=self.dtype, 930 static_inner_shape=new_static_inner_shape) 931 if inner_axis <= self.num_row_partitions: 932 # Here, we are merging the row_partitions, 933 # but the inner_shape is unchanged. 934 if outer_axis == 0: 935 # There is no need to merge axes before the first, just truncate them. 936 return DynamicRaggedShape( 937 self._row_partitions[inner_axis:], 938 self.inner_shape, 939 dtype=self.dtype, 940 static_inner_shape=self._static_inner_shape) 941 prefix_rp = self._row_partitions[:outer_axis - 1] 942 suffix_rp = self._row_partitions[inner_axis:] 943 internal_rp = self._row_partitions[outer_axis - 1:inner_axis] 944 new_rp = prefix_rp + (_merge_row_partitions(internal_rp),) + suffix_rp 945 946 return DynamicRaggedShape( 947 new_rp, self.inner_shape, dtype=self.dtype, 948 static_inner_shape=self._static_inner_shape) 949 elif outer_axis > self.num_row_partitions: 950 # In this scenario, only the inner_shape is changed. 951 # Example #1: 952 # if [2, (1, 2), 5, 3], num_row_partitions=1, outer_axis=2, inner_axis=3. 953 # Result: [2, (1, 2), 15], num_row_partitions=1, outer_axis=2, 954 # inner_axis=3. 955 (new_inner_shape, new_static_inner_shape) = _merge_inner_shape( 956 self._inner_shape, self._static_inner_shape, 957 outer_axis-self.num_row_partitions, 958 inner_axis-self.num_row_partitions) 959 return DynamicRaggedShape( 960 self._row_partitions, 961 new_inner_shape, dtype=self.dtype, 962 static_inner_shape=new_static_inner_shape) 963 else: 964 # Here, both inner_shape and row_partitions are changed. 965 rank = self.rank 966 if rank is None: 967 raise ValueError("Cannot merge_dims of the inner shape if the " + 968 "dimension of inner_shape is unknown") 969 if outer_axis == 0: 970 new_inner_shape = self._alt_inner_shape(rank - inner_axis) 971 return DynamicRaggedShape._from_inner_shape(new_inner_shape) 972 else: 973 prefix = self._row_partitions[:outer_axis-1] 974 suffix = _merge_row_partitions(self._row_partitions[outer_axis-1:]) 975 new_inner_shape = self._alt_inner_shape(rank - inner_axis) 976 num_merged_inner = inner_axis - self.num_row_partitions 977 prod = _reduce_prod_patch(self._inner_shape[1:num_merged_inner + 1]) 978 tail_suffix = RowPartition.from_row_splits(suffix.row_splits() * prod) 979 return DynamicRaggedShape(prefix + (tail_suffix,), new_inner_shape) 980 981 def with_dtype(self, dtype): 982 """Change the dtype of the shape.""" 983 if dtype == self.dtype: 984 return self 985 else: 986 return DynamicRaggedShape( 987 self.row_partitions, self.inner_shape, dtype=dtype) 988 989 def _merge_with(self, other: "DynamicRaggedShape") -> "DynamicRaggedShape": 990 """Merge two shapes that are equal modulo num_row_partitions. 991 992 The resulting num_row_partitions is the maximum of the two 993 num_row_partitions. 

    Args:
      other: a DynamicRaggedShape representing the same shape with a possibly
        different number of row partitions.

    Returns:
      A DynamicRaggedShape with the same shape and the maximum of the
      num_row_partitions of the two shapes.
    """
    max_num_row_partitions = max(self.num_row_partitions,
                                 other.num_row_partitions)
    a = self._with_num_row_partitions(max_num_row_partitions)
    b = other._with_num_row_partitions(max_num_row_partitions)
    new_row_partitions = [
        rp_a._merge_precomputed_encodings(rp_b)
        for (rp_a, rp_b) in zip(a._row_partitions, b._row_partitions)
    ]
    new_dtype = b.dtype if a.dtype == dtypes.int32 else dtypes.int64

    new_static_inner_shape = a._static_inner_shape.merge_with(
        b._static_inner_shape)
    new_inner_shape = a._inner_shape
    return DynamicRaggedShape(new_row_partitions, new_inner_shape, new_dtype,
                              True, new_static_inner_shape)

  def _merge_with_spec(
      self, other: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape":
    """Merge a spec with a DynamicRaggedShape."""
    # TODO(martinz): add tests for dynamic inconsistencies.
    max_num_row_partitions = max(self.num_row_partitions,
                                 other.num_row_partitions)
    a = self._with_num_row_partitions(max_num_row_partitions)
    b = other._with_num_row_partitions(max_num_row_partitions)
    new_row_partitions = [rp_a._merge_with_spec(rp_b) for (rp_a, rp_b) in
                          zip(a._row_partitions, b._row_partitions)]
    new_dtype = b.dtype if a.dtype == dtypes.int32 else dtypes.int64

    new_static_inner_shape = a._static_inner_shape.merge_with(
        b._static_inner_shape)
    new_inner_shape = a._inner_shape
    return DynamicRaggedShape(
        new_row_partitions,
        new_inner_shape,
        new_dtype,
        True,
        new_static_inner_shape)

  def _as_row_partitions(self):
    """Returns row partitions representing this shape.

    In order to represent a shape as row partitions, the rank of the shape
    must be known, and the shape must have rank at least one.

    Returns:
      A list of RowPartition objects.
    Raises:
      ValueError: if the shape cannot be represented by RowPartitions.
1051 """ 1052 rank = self.rank 1053 if rank is None: 1054 raise ValueError("rank must be known for _as_row_partitions") 1055 elif rank < 1: 1056 raise ValueError("rank must be >= 1 for _as_row_partitions") 1057 fully_ragged = self._with_num_row_partitions(rank - 1) 1058 return fully_ragged.row_partitions 1059 1060 def _validate_flat_values_dynamically(self, flat_values): 1061 """Test if flat_values have the right nvals dynamically.""" 1062 if self.row_partitions: 1063 assert_op = check_ops.assert_equal( 1064 self.row_partitions[-1].nvals(), 1065 array_ops.shape(flat_values, out_type=self.dtype)[0], 1066 message="Last row partition does not match flat_values.") 1067 return control_flow_ops.with_dependencies([assert_op], flat_values) 1068 return flat_values 1069 1070 def _validate_flat_values(self, flat_values): 1071 """Test if flat_values have the right nvals.""" 1072 if not isinstance(flat_values, ops.Tensor): 1073 return flat_values 1074 if self.row_partitions: 1075 last_row_partition = self.row_partitions[-1] 1076 flat_values_shape = flat_values.shape 1077 if flat_values_shape is None: 1078 return self._validate_flat_values_dynamically(flat_values) 1079 first_dim_flat_values = flat_values_shape[0] 1080 if isinstance(first_dim_flat_values, tensor_shape.Dimension): 1081 first_dim_flat_values = first_dim_flat_values.value 1082 if first_dim_flat_values is None: 1083 return self._validate_flat_values_dynamically(flat_values) 1084 static_nvals = last_row_partition.static_nvals 1085 if static_nvals is None: 1086 return self._validate_flat_values_dynamically(flat_values) 1087 if first_dim_flat_values != static_nvals: 1088 raise ValueError("Last row partition does not match flat_values.") 1089 return flat_values 1090 1091 def _add_row_partitions(self, flat_values, validate=False): 1092 """Add row partitions to flat_values, if necessary. 1093 1094 If the shape is truly ragged, then this adds the row_partitions. 1095 1096 The shape is dense, then this just returns flat_values. 1097 1098 Args: 1099 flat_values: the flat_values of a ragged tensor with this shape, or a 1100 dense tensor with this shape. 1101 validate: validate the flat_values have the right first dimension. 1102 1103 Returns: 1104 flat_values reshaped to have row_partitions. 1105 """ 1106 if self.row_partitions: 1107 if validate: 1108 flat_values = self._validate_flat_values(flat_values) 1109 return ragged_tensor.RaggedTensor._from_nested_row_partitions( 1110 flat_values, self.row_partitions, validate=False) 1111 else: 1112 return flat_values 1113 1114 class Spec: 1115 """A Spec for DynamicRaggedShape: similar to a static shape.""" 1116 1117 def __init__(self, row_partitions: Tuple[RowPartitionSpec, ...], 1118 static_inner_shape: tensor_shape.TensorShape, 1119 dtype: dtypes.DType): 1120 """Create a Spec given row partitions, a static inner shape, and a dtype. 1121 1122 Args: 1123 row_partitions: A sequence of `RowPartitionSpec`s describing how the 1124 ragged shape is partitioned. 1125 static_inner_shape: The static shape of the flat_values. 1126 dtype: The DType used to encode the shape (tf.int64 or tf.int32). 1127 """ 1128 # Independent validation and coercion of each argument. 
      if not isinstance(row_partitions, Iterable):
        raise TypeError("row_partitions should be an Iterable")

      row_partitions = tuple(row_partitions)

      static_inner_shape = tensor_shape.as_shape(static_inner_shape)

      dtype = dtypes.as_dtype(dtype)

      if not all(isinstance(rp, RowPartitionSpec) for rp in row_partitions):
        raise TypeError(
            "row_partitions should be an Iterable of RowPartitionSpecs")

      if dtype != dtypes.int32 and dtype != dtypes.int64:
        raise ValueError("dtype must be tf.int32 or tf.int64")

      # All fields are now typechecked and internally consistent.
      for spec in row_partitions:
        if spec.dtype != dtype:
          raise ValueError(
              f"dtype of {spec!r} is {spec.dtype!r}: expected {dtype!r}")

      row_partitions = tuple(row_partitions)

      inner_rank = static_inner_shape.rank

      if inner_rank == 0:
        if row_partitions:
          raise ValueError(
              "If row_partitions are provided, must have inner_rank > 0")
      else:
        num_slices_in_dimension = []  # type: Sequence[tensor_shape.Dimension]

        # We first attempt to calculate num_slices_in_dimension through a
        # forward pass, using nrows[k] = nrows[k-1] * uniform_row_length
        # and other tricks.
        for i in range(len(row_partitions)):
          rp = row_partitions[i]
          result = tensor_shape.Dimension(rp.nrows)
          if i > 0:
            previous_rp = row_partitions[i - 1]
            result = result.merge_with(previous_rp.nvals)
            result = result.merge_with(num_slices_in_dimension[-1] *
                                       previous_rp.uniform_row_length)
          num_slices_in_dimension.append(result)
        # In the last step of the forward pass,
        # we combine nvals and the first dimension in static_inner_shape.
        if row_partitions:
          last_rp = row_partitions[-1]
          result = (num_slices_in_dimension[-1] *
                    last_rp.uniform_row_length).merge_with(last_rp.nvals)
          if inner_rank is not None:
            result = result.merge_with(
                tensor_shape.dimension_at_index(static_inner_shape, 0))
            static_inner_shape = result + static_inner_shape[1:]
          num_slices_in_dimension.append(result)

        # Now, we start a backward pass.
        for i in range(len(num_slices_in_dimension) - 1, 0, -1):
          num_slices_in_dimension[i - 1] = num_slices_in_dimension[
              i - 1].merge_with(
                  _safe_floor_div(num_slices_in_dimension[i],
                                  row_partitions[i - 1].uniform_row_length))

        # Finally, we construct the partitions.
        row_partitions = [
            RowPartitionSpec(  # pylint: disable=g-complex-comprehension
                nrows=num_slices_in_dimension[i].value,
                uniform_row_length=rp.uniform_row_length,
                nvals=num_slices_in_dimension[i + 1].value,
                dtype=rp.dtype) for i, rp in enumerate(row_partitions)
        ]

      self._static_inner_shape = static_inner_shape
      self._inner_shape = tensor_spec.TensorSpec(
          [inner_rank], dtype=dtype)
      self._row_partitions = row_partitions

    def __repr__(self):
      return (
          f"DynamicRaggedShape.Spec(row_partitions={self._row_partitions!r}, " +
          f"static_inner_shape={self._static_inner_shape!r}, " +
          f"dtype={self.dtype!r})")

    @classmethod
    def from_value(cls, value: Any) -> "DynamicRaggedShape.Spec":
      """Create a Spec from a DynamicRaggedShape."""
      # super().from_value(...) creates an object, but there is no validation.
      # No methods can be trusted on the object, just the properties.
      initial = super(DynamicRaggedShape.Spec, cls).from_value(value)

      # However, since value is a DynamicRaggedShape, we
      # can guarantee that initial._inner_shape.shape.rank == 1

      # Moreover, if inner_shape.shape[0] is not None, then
      # static_inner_shape.rank is not None.

      return DynamicRaggedShape.Spec(
          row_partitions=initial._row_partitions,
          static_inner_shape=initial._static_inner_shape,
          dtype=initial._inner_shape.dtype)

    # TODO(martinz): it is unclear what the default uniformity of RowPartitions
    # should be, so I am moving this to experimental until we figure it out.
    # Also, while I have specified this is meant to represent a shape of a
    # proper Tensor instead of a RaggedTensor, this is also subject to
    # interpretation.
    @classmethod
    def _from_tensor_shape(cls,
                           shape: Any,
                           num_row_partitions: int,
                           dtype: dtypes.DType) -> "DynamicRaggedShape.Spec":
      """Creates a `DynamicRaggedShape.Spec` corresponding to a `tf.TensorShape`.

      It is assumed that this is a `tf.TensorShape` coming from a
      `tf.TensorSpec`, not from `RaggedTensor.shape`.

      In addition to the shape, we need to know the number of row partitions,
      and the dtype used in the shape (tf.int32 or tf.int64).

      Within the dimensions that are partitioned, all dimensions are assumed
      to be uniform.

      Args:
        shape: a TensorShape.
        num_row_partitions: the ragged rank of the RaggedShape.
        dtype: the dtype of the shape (not the tensor); tf.int64 or tf.int32.

      Returns:
        a DynamicRaggedShape.Spec representing a TensorShape.
      """
      if dtype != dtypes.int32 and dtype != dtypes.int64:
        raise ValueError("dtype must be tf.int32 or tf.int64")

      shape = tensor_shape.as_shape(shape)
      if shape.rank is None:
        row_partitions = [
            RowPartitionSpec(dtype=dtype) for _ in range(num_row_partitions)
        ]
        return DynamicRaggedShape.Spec(
            row_partitions=row_partitions,
            static_inner_shape=tensor_shape.TensorShape(None),
            dtype=dtype)

      if shape.rank <= 1:
        # Create a scalar or vector shape.
        if num_row_partitions:
          raise ValueError("num_row_partitions should be zero " +
                           "if shape is a scalar or vector.")
        return DynamicRaggedShape.Spec(
            row_partitions=[], static_inner_shape=shape, dtype=dtype)

      if shape.rank <= num_row_partitions:
        raise ValueError("num_row_partitions must be less than rank")

      num_elements_so_far = tensor_shape.dimension_value(shape[0])
      rp_specs = []
      for i in range(num_row_partitions):
        current_dim = tensor_shape.dimension_value(shape[i + 1])
        if current_dim is None or num_elements_so_far is None:
          nvals = None
        else:
          nvals = num_elements_so_far * current_dim
        rp_specs.append(RowPartitionSpec(
            nrows=num_elements_so_far,
            nvals=nvals,
            uniform_row_length=current_dim,
            dtype=dtype))
        num_elements_so_far = nvals

      static_inner_shape = tensor_shape.TensorShape(
          [num_elements_so_far]) + shape[num_row_partitions + 1:]
      return DynamicRaggedShape.Spec(
          row_partitions=rp_specs,
          static_inner_shape=static_inner_shape,
          dtype=dtype)

    @classmethod
    def _from_spec(
        cls,
        spec: Union["DynamicRaggedShape.Spec", ragged_tensor.RaggedTensorSpec,
                    tensor_spec.TensorSpec],
        dtype: dtypes.DType = dtypes.int64) -> "DynamicRaggedShape.Spec":
      """Create a TypeSpec for the shape of an object with a given TypeSpec.

      I.e., if `x_spec = tf.type_spec_from_value(x)`, then
      `DynamicRaggedShape.from_spec(x_spec)` returns a TypeSpec compatible with
      `tf.type_spec_from_value(tf.shape(x))`.

      >>> rt = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
      >>> rt_spec = tf.type_spec_from_value(rt)
      >>> rt_shape = DynamicRaggedShape.from_tensor(rt)

      >>> shape_spec_1 = tf.type_spec_from_value(rt_shape)
      >>> shape_spec_2 = DynamicRaggedShape.Spec._from_spec(rt_spec)
      >>> assert shape_spec_1.is_compatible_with(shape_spec_2)

      Args:
        spec: a Spec of a Tensor or RaggedTensor.
        dtype: the default dtype (if necessary).

      Returns:
        A Spec of the shape of a Tensor or RaggedTensor.

      """
      # TODO(martinz): Add StructuredTensor.Spec when it's easy.
      if isinstance(spec, DynamicRaggedShape.Spec):
        return spec
      elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
        return cls._from_tensor_shape(spec.shape,
                                      spec.ragged_rank,
                                      spec.row_splits_dtype)
      elif isinstance(spec, tensor_spec.TensorSpec):
        return cls._from_tensor_shape(shape=spec.shape,
                                      num_row_partitions=0,
                                      dtype=dtype)

    @property
    def dtype(self) -> dtypes.DType:
      return self._inner_shape.dtype

    @property
    def inner_rank(self) -> Optional[int]:
      if self._static_inner_shape.rank is not None:
        return self._static_inner_shape.rank
      if self._inner_shape.shape.rank is None:
        return None
      return tensor_shape.dimension_value(self._inner_shape.shape[0])

    @property
    def num_row_partitions(self) -> int:
      return len(self._row_partitions)

    @property
    def rank(self) -> Optional[int]:
      inner_rank = self.inner_rank
      return None if inner_rank is None else inner_rank + self.num_row_partitions

    def _dimension(self, index: int) -> Optional[int]:
      """Get the size of dimension index, if known statically."""
      if index == 0:
        if self._row_partitions:
          return self._row_partitions[0].nrows
        elif self.inner_rank is None:
          return None
        elif self.inner_rank == 0:
          raise ValueError("Index out of range: 0.")
        else:
          return tensor_shape.dimension_value(self._static_inner_shape[0])
      if index <= len(self._row_partitions):
        return self._row_partitions[index - 1].uniform_row_length

      relative_index = index - self.num_row_partitions

      if self.inner_rank is None:
        return None
      elif self.inner_rank <= relative_index:
        raise ValueError(f"Index out of range: {index}.")
      else:
        return tensor_shape.dimension_value(
            self._static_inner_shape[relative_index])

    def _num_slices_in_dimension(self, axis: int) -> Optional[int]:
      """The total size of a dimension (like nvals).

      This is a static version of DynamicRaggedShape._num_slices_in_dimension()

      Example:

      ```
      shape = DynamicRaggedShape.Spec(
          _row_partitions=[
              RowPartitionSpec(nrows=3, nvals=14, dtype=tf.int32),
              RowPartitionSpec(nrows=14, nvals=25, dtype=tf.int32)
          ],
          _static_inner_shape=tf.TensorShape([25, 3, 4]),
          _inner_shape=tf.TensorSpec(tf.TensorShape([3]), dtype=tf.int32))
      shape._num_slices_in_dimension(0) = 3
      shape._num_slices_in_dimension(1) = 14
      shape._num_slices_in_dimension(2) = 25
      shape._num_slices_in_dimension(3) = 3
      shape._num_slices_in_dimension(4) = 4
      shape._num_slices_in_dimension(-2) = 3
      ```

      Args:
        axis: the last dimension to include.

      Returns:
        the number of values in a dimension.
      """
      if not isinstance(axis, int):
        raise TypeError("axis must be an integer")
      axis = array_ops.get_positive_axis(axis, self.rank, ndims_name="rank")

      if axis == 0:
        return self._dimension(0)
      if axis <= self.num_row_partitions:
        # TODO(martinz): use nvals OR nrows, whichever is defined.
        return self._row_partitions[axis - 1].nvals
      remainder = axis - (self.num_row_partitions - 1)
      head_inner_shape = self._static_inner_shape[:remainder]
      return head_inner_shape.num_elements()

    def with_dtype(self, dtype: dtypes.DType) -> "DynamicRaggedShape.Spec":
      """Return the same spec, but with a different DType."""
      new_rp_specs = [rp.with_dtype(dtype) for rp in self._row_partitions]
      return DynamicRaggedShape.Spec(
          row_partitions=new_rp_specs,
          static_inner_shape=self._static_inner_shape,
          dtype=dtype)

    def _merge_with(
        self,
        other: "DynamicRaggedShape.Spec") -> "DynamicRaggedShape.Spec":
      """Merges all information between two specs.

      Specs are expected to represent the same information modulo
      num_row_partitions.

      If the specs are of different ranks, then fail.

      Args:
        other: another Spec of the same rank.

      Returns:
        a Spec with the union of information.
      """
      max_num_row_partitions = max(self.num_row_partitions,
                                   other.num_row_partitions)
      a = self._with_num_row_partitions(max_num_row_partitions)
      b = other._with_num_row_partitions(max_num_row_partitions)

      new_rp = [
          a._merge_with(b)
          for (a, b) in zip(a._row_partitions, b._row_partitions)
      ]

      new_static_inner_shape = a._static_inner_shape.merge_with(
          b._static_inner_shape)

      dtype = b.dtype if (a.dtype == dtypes.int32) else dtypes.int64

      return DynamicRaggedShape.Spec(
          new_rp, new_static_inner_shape, dtype=dtype)

    def _with_num_row_partitions(
        self,
        new_num_row_partitions: int) -> "DynamicRaggedShape.Spec":
      """Change the number of row partitions in the spec."""
      rank = self.rank
      if rank is None:
        raise ValueError(
            "Changing num_row_partitions with unknown rank unsupported")
      if new_num_row_partitions > max(rank - 1, 0):
        raise ValueError("Number of row partitions too large")
      if new_num_row_partitions < 0:
        raise ValueError("Number of row partitions negative")
      if self.num_row_partitions == new_num_row_partitions:
        return self
      elif self.num_row_partitions < new_num_row_partitions:
        # TODO(martinz): Consider swapping.
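        # Descriptive note (added): growing num_row_partitions re-partitions
        # the leading dimensions of the static inner shape into uniform row
        # partitions and appends them to the existing ones.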
        rp_delta = new_num_row_partitions - self.num_row_partitions
        tail_shape = DynamicRaggedShape.Spec._from_tensor_shape(
            self._static_inner_shape, rp_delta, self.dtype)
        return DynamicRaggedShape.Spec(
            row_partitions=self._row_partitions + tail_shape._row_partitions,
            static_inner_shape=tail_shape._static_inner_shape,
            dtype=self.dtype)
      else:
        assert self.num_row_partitions > new_num_row_partitions
        new_row_partitions = self._row_partitions[:new_num_row_partitions]
        last_row_partition = new_row_partitions[-1]
        old_row_partitions = self._row_partitions[new_num_row_partitions:]
        new_static_inner_shape = (
            tensor_shape.TensorShape(
                [last_row_partition.nvals] +
                [x.uniform_row_length for x in old_row_partitions]) +
            self._static_inner_shape[1:])
        return DynamicRaggedShape.Spec(
            new_row_partitions, new_static_inner_shape, self.dtype)

    def _set_rank_if_unknown(self, new_rank: int) -> "DynamicRaggedShape.Spec":
      """Ensures this has a known rank at least new_rank."""
      if new_rank is None:
        raise TypeError("new_rank is None, but expected int")
      if new_rank < 0:
        raise ValueError("Rank must be non-negative")
      current_rank = self.rank
      if current_rank is not None and current_rank < new_rank:
        raise ValueError(
            "Rank is {current_rank}, expected at least {new_rank}.".format(
                current_rank=current_rank, new_rank=new_rank))

      if current_rank is not None:
        return self

      if self._row_partitions:
        new_inner_rank = max(new_rank - self.num_row_partitions, 1)
        first_dim = self._row_partitions[-1].nvals
        static_inner_shape = tensor_shape.TensorShape(
            [first_dim] + [None] * (new_inner_rank - 1))
      else:
        static_inner_shape = tensor_shape.TensorShape([None] * new_rank)

      return DynamicRaggedShape.Spec(
          row_partitions=self._row_partitions,
          static_inner_shape=static_inner_shape,
          dtype=self.dtype)

    def _truncate(self, new_rank: int) -> "DynamicRaggedShape.Spec":
      """Truncate a ragged shape spec.

      For example, if the original spec s was for a shape:
      [3, [4, 1], 2, 7]

      Then truncate_dynamic_ragged_shape_spec(s, 3) is a spec for:
      [3, [4, 1], 2]

      Args:
        new_rank: the new rank

      Returns:
        A truncated DynamicRaggedShape.Spec.
1554 """ 1555 if self.rank is None: 1556 return self._set_rank_if_unknown(new_rank)._truncate(new_rank) 1557 1558 if new_rank == 0: 1559 return DynamicRaggedShape.Spec._from_tensor_shape([], 0, self.dtype) 1560 1561 if new_rank == 1: 1562 vector_size = self._dimension(0) 1563 return DynamicRaggedShape.Spec._from_tensor_shape([vector_size], 0, 1564 self.dtype) 1565 1566 if new_rank < self.num_row_partitions + 1: 1567 new_row_partitions = self._row_partitions[:new_rank - 1] 1568 new_static_inner_shape = tensor_shape.TensorShape( 1569 [new_row_partitions[-1].nvals]) 1570 return DynamicRaggedShape.Spec( 1571 row_partitions=new_row_partitions, 1572 static_inner_shape=new_static_inner_shape, 1573 dtype=self.dtype) 1574 else: 1575 remainder = new_rank - self.num_row_partitions 1576 new_static_inner_shape = self._static_inner_shape[:remainder] 1577 return DynamicRaggedShape.Spec( 1578 row_partitions=self._row_partitions, 1579 static_inner_shape=new_static_inner_shape, 1580 dtype=self.dtype) 1581 1582 def _to_tensor_shape(self): 1583 """Get a tensor shape corresponding to this type.""" 1584 alt = self 1585 if alt._static_inner_shape.rank is None: 1586 return tensor_shape.TensorShape(None) 1587 if alt._static_inner_shape.rank == 0: 1588 assert not alt._row_partitions 1589 return alt._static_inner_shape 1590 prefix = [alt._dimension(0)] 1591 prefix.extend([rp.uniform_row_length for rp in alt._row_partitions]) 1592 suffix = alt._static_inner_shape[1:] 1593 return tensor_shape.TensorShape(prefix) + suffix 1594 1595 1596def broadcast_dynamic_shape(shape_x: DynamicRaggedShape, 1597 shape_y: DynamicRaggedShape) -> DynamicRaggedShape: 1598 """Returns the shape formed by broadcasting two shapes to be compatible. 1599 1600 1. If shape_x and shape_y both have row_partitions, then fail if their dtypes 1601 don't match. 1602 2. If neither has row_partitions and they have different dtypes, 1603 go with int64. 1604 3. If one has row_partitions, go with that dtype. 1605 1606 Args: 1607 shape_x: A `DynamicRaggedShape` 1608 shape_y: A `DynamicRaggedShape` 1609 1610 Returns: 1611 A `DynamicRaggedShape`. 1612 Raises: 1613 ValueError: If `shape_x` and `shape_y` are not broadcast-compatible. 1614 """ 1615 if not isinstance(shape_x, DynamicRaggedShape): 1616 raise TypeError("shape_x must be a DynamicRaggedShape") 1617 if not isinstance(shape_y, DynamicRaggedShape): 1618 raise TypeError("shape_y must be a DynamicRaggedShape") 1619 1620 return broadcast_dynamic_shape_extended(shape_x, shape_y)[0] 1621 1622 1623def broadcast_to(rt_input, shape: DynamicRaggedShape): 1624 """Broadcasts a potentially ragged tensor to a ragged shape. 1625 1626 Tiles `rt_input` as necessary to match the given shape. 1627 1628 Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`. 1629 1630 Args: 1631 rt_input: The potentially ragged tensor to broadcast. 1632 shape: A `DynamicRaggedShape` 1633 1634 Returns: 1635 A potentially ragged tensor whose values are taken from 1636 `rt_input`, and whose shape matches `shape`. 
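
  Example (an illustrative sketch; `DynamicRaggedShape.from_lengths` is
  defined earlier in this module, and the values are made up):

    x = constant_op.constant([[10], [20]])               # dense, shape [2, 1]
    target = DynamicRaggedShape.from_lengths([2, (3, 1)])
    broadcast_to(x, target)  # ragged result: [[10, 10, 10], [20]]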
1637 """ 1638 if not isinstance(shape, DynamicRaggedShape): 1639 raise TypeError("shape must be a DynamicRaggedShape") 1640 rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input) 1641 origin_shape = None 1642 if ragged_tensor.is_ragged(rt_input): 1643 if shape.num_row_partitions != 0: 1644 if rt_input.row_splits.dtype != shape.dtype: 1645 raise ValueError("Cannot coerce row_splits.dtype") 1646 else: 1647 shape = shape.with_dtype(rt_input.row_splits.dtype) 1648 origin_shape = DynamicRaggedShape.from_tensor(rt_input) 1649 else: 1650 if shape.num_row_partitions != 0: 1651 origin_shape = DynamicRaggedShape.from_tensor(rt_input, dtype=shape.dtype) 1652 else: 1653 origin_shape = DynamicRaggedShape.from_tensor(rt_input, 1654 dtype=dtypes.int64) 1655 shape = shape.with_dtype(dtype=dtypes.int64) 1656 1657 broadcaster = _get_broadcaster(origin_shape, shape) 1658 return broadcaster.broadcast(rt_input) 1659 1660 1661def broadcast_dynamic_shape_extended( 1662 a: DynamicRaggedShape, b: DynamicRaggedShape 1663): # -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster] 1664 """Gets the smallest shape to which a and b can broadcast. 1665 1666 In order to create the smallest shape, one must also do most of the 1667 work to figure out how to transform from the shapes given. Thus, in addition 1668 to returning the shape, it also creates transformations from the 1669 original shapes to the result. 1670 1671 This is the equivalent of: 1672 1673 c = broadcast_dynamic_shape(a, b) 1674 ac = get_broadcaster(a, c) 1675 bc = get_broadcaster(b, c) 1676 return (c, ac, bc) 1677 1678 Args: 1679 a: a DynamicRaggedShape 1680 b: a DynamicRaggedShape 1681 1682 Returns: 1683 A triple of a shape and two broadcasters. 1684 """ 1685 if a.row_partitions and b.row_partitions: 1686 if a.dtype != b.dtype: 1687 raise ValueError("Dtypes don't match") 1688 elif a.dtype != b.dtype: 1689 if a.row_partitions: 1690 b = b.with_dtype(a.dtype) 1691 elif b.row_partitions: 1692 a = a.with_dtype(b.dtype) 1693 else: 1694 a = a.with_dtype(dtypes.int64) 1695 b = b.with_dtype(dtypes.int64) 1696 1697 if (a.rank is None or b.rank is None): 1698 raise ValueError("Unable to broadcast: unknown rank") 1699 elif a.rank == 0: 1700 return (b, _Broadcaster(a, b, []), _get_identity_broadcaster(b)) 1701 elif b.rank == 0: 1702 return (a, _get_identity_broadcaster(a), _Broadcaster(b, a, [])) 1703 elif a.rank == 1 and b.rank == 1: 1704 [a_layer, b_layer, 1705 target] = _broadcast_dynamic_shape_one_layer(a.inner_shape, b.inner_shape) 1706 target_shape = DynamicRaggedShape._from_inner_shape(target) # pylint: disable=protected-access 1707 return (target_shape, _Broadcaster(a, target_shape, [a_layer]), 1708 _Broadcaster(b, target_shape, [b_layer])) 1709 1710 if a.rank > b.rank: 1711 (c, bc, ac) = _broadcast_dynamic_shape_extended_helper(b, a) # pylint: disable=arguments-out-of-order 1712 1713 return (c, ac, bc) 1714 1715 return _broadcast_dynamic_shape_extended_helper(a, b) 1716 1717 1718def _row_partitions_identical(shape_a, shape_b): 1719 """Returns True iff all row_partitions in shapes are identical.""" 1720 return ((shape_a.num_row_partitions == shape_b.num_row_partitions) and all( 1721 a is b for a, b in zip(shape_a.row_partitions, shape_b.row_partitions))) 1722 1723 1724# TODO(martinz): Preserve shapes better (see CL/414806185) 1725@dispatch.dispatch_for_binary_elementwise_apis(ragged_tensor.RaggedOrDense, 1726 ragged_tensor.RaggedOrDense) 1727def ragged_binary_elementwise_op_impl(op, x, y): 1728 """Binary elementwise api handler for 
RaggedTensors.""" 1729 x_is_ragged = ragged_tensor.is_ragged(x) 1730 y_is_ragged = ragged_tensor.is_ragged(y) 1731 1732 # Convert args to tensors. 1733 x = ragged_tensor.convert_to_tensor_or_ragged_tensor( 1734 x, preferred_dtype=(y.dtype if y_is_ragged else None)) 1735 y = ragged_tensor.convert_to_tensor_or_ragged_tensor( 1736 y, preferred_dtype=x.dtype) 1737 1738 if x_is_ragged and y_is_ragged: 1739 x, y = ragged_tensor.match_row_splits_dtypes(x, y) 1740 1741 if ((x_is_ragged and y_is_ragged) or 1742 (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or 1743 (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)): 1744 shape_x = DynamicRaggedShape.from_tensor(x) 1745 shape_y = DynamicRaggedShape.from_tensor(y) 1746 if shape_x.dtype != shape_y.dtype: 1747 if not x_is_ragged: 1748 shape_x = shape_x.with_dtype(shape_y.dtype) 1749 elif not y_is_ragged: 1750 shape_y = shape_y.with_dtype(shape_x.dtype) 1751 1752 if _row_partitions_identical(shape_x, shape_y): 1753 # At this point, both x and y must be ragged. 1754 return shape_x._add_row_partitions( # pylint: disable=protected-access 1755 op(x.flat_values, y.flat_values), validate=False) 1756 1757 (shape_z, bcast_xz, 1758 bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y) 1759 x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False) 1760 y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False) 1761 z_flat = op(x_new_flat, y_new_flat) 1762 return shape_z._add_row_partitions(z_flat, validate=True) # pylint: disable=protected-access 1763 1764 x_values = x.flat_values if ragged_tensor.is_ragged(x) else x 1765 y_values = y.flat_values if ragged_tensor.is_ragged(y) else y 1766 mapped_values = op(x_values, y_values) 1767 if isinstance(mapped_values, bool): 1768 return mapped_values # Special case for tensor_equals. 1769 if ragged_tensor.is_ragged(x): 1770 return x.with_flat_values(mapped_values) 1771 else: 1772 return y.with_flat_values(mapped_values) 1773 1774 1775@dispatch.dispatch_for_binary_elementwise_assert_apis( 1776 ragged_tensor.RaggedOrDense, ragged_tensor.RaggedOrDense) 1777def ragged_binary_elementwise_assert_op_impl(op, x, y): 1778 """Binary elementwise assert api handler for RaggedTensors. 1779 1780 This handles binary assert operations for ragged tensors. Compared with 1781 `ragged_binary_elementwise_op_impl`, this handler does not compute a ragged 1782 tensor as output. Instead, it applies the assert operation `op` to input 1783 tensors based on their ragged shapes and flat_values, and returns the result 1784 of the assertion operation. 1785 1786 Args: 1787 op: a binary assert operation on Tensors. 1788 x: something that can be coerced to a Tensor or RaggedTensor. 1789 y: something that can be coerced to a Tensor or RaggedTensor. 1790 1791 Returns: 1792 the result of the assertion operation. 1793 1794 """ 1795 x_is_ragged = ragged_tensor.is_ragged(x) 1796 y_is_ragged = ragged_tensor.is_ragged(y) 1797 1798 # Convert args to tensors. 
1799 x = ragged_tensor.convert_to_tensor_or_ragged_tensor( 1800 x, preferred_dtype=(y.dtype if y_is_ragged else None)) 1801 y = ragged_tensor.convert_to_tensor_or_ragged_tensor( 1802 y, preferred_dtype=x.dtype) 1803 1804 if x_is_ragged and y_is_ragged: 1805 x, y = ragged_tensor.match_row_splits_dtypes(x, y) 1806 1807 if ((x_is_ragged and y_is_ragged) or 1808 (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or 1809 (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)): 1810 shape_x = DynamicRaggedShape.from_tensor(x) 1811 shape_y = DynamicRaggedShape.from_tensor(y) 1812 if shape_x.dtype != shape_y.dtype: 1813 if not x_is_ragged: 1814 shape_x = shape_x.with_dtype(shape_y.dtype) 1815 elif not y_is_ragged: 1816 shape_y = shape_y.with_dtype(shape_x.dtype) 1817 1818 if _row_partitions_identical(shape_x, shape_y): 1819 # At this point, both x and y must be ragged. 1820 return op(x.flat_values, y.flat_values) 1821 1822 (_, bcast_xz, bcast_yz) = broadcast_dynamic_shape_extended(shape_x, shape_y) 1823 x_new_flat = bcast_xz.broadcast_flat_values(x, inner_dimensions=False) 1824 y_new_flat = bcast_yz.broadcast_flat_values(y, inner_dimensions=False) 1825 return op(x_new_flat, y_new_flat) 1826 1827 x_values = x.flat_values if ragged_tensor.is_ragged(x) else x 1828 y_values = y.flat_values if ragged_tensor.is_ragged(y) else y 1829 return op(x_values, y_values) 1830 1831 1832def _find_dtype_helper(value, preferred): 1833 """Helper for _find_dtype.""" 1834 if preferred is not None: 1835 return preferred 1836 elif isinstance(value, RowPartition): 1837 return value.dtype 1838 elif isinstance(value, dtypes.DType): 1839 return value 1840 elif isinstance(value, int): 1841 return None 1842 elif isinstance(value, list): 1843 return None 1844 elif isinstance(value, tuple): 1845 return None 1846 elif isinstance(value, core.Tensor): 1847 return value.dtype 1848 return value.dtype 1849 1850 1851def _find_dtype(value, preferred): 1852 """Returns the preferred dtype of value or preferred if preferred != None. 1853 1854 This is used as an operator to pass over multiple objects in decreasing order 1855 of priority until there is a preferred dtype for one. For example, if you were 1856 adding three tensor-ish things (some tensors, some lists), and needed a 1857 preferred dtype, you could use this as: 1858 1859 def adding(a, b, c, dtype = None): 1860 dtype = _find_dtype(a, dtype) 1861 dtype = _find_dtype(b, dtype) 1862 dtype = _find_dtype(c, dtype) 1863 if dtype is None: 1864 dtype = tf.float32 1865 ...Code continues here... 1866 1867 Args: 1868 value: a list, value, RowPartition, or tensor. 1869 preferred: a given dtype. If not None, this will be returned. 1870 1871 Returns: 1872 an optional dtype. 1873 """ 1874 result = _find_dtype_helper(value, preferred) 1875 if (result == dtypes.int64 or result == dtypes.int32 or result is None): 1876 return result 1877 raise ValueError("Illegal dtype: " + str(result)) 1878 1879 1880def _find_dtype_iterable( 1881 iterable: Iterable[Any], 1882 dtype: Optional[dtypes.DType]) -> Optional[dtypes.DType]: 1883 """Find the preferred dtype of a list of objects. 1884 1885 This will go over the iterable, and use the first object with a preferred 1886 dtype. The dtype passed has highest priority if it is not None. 1887 1888 Args: 1889 iterable: an iterable with things that might have a dtype. 1890 dtype: an overriding dtype, or None. 1891 1892 Returns: 1893 an optional dtype. 
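
  For example (illustrative), `_find_dtype_iterable([2, rp], None)` returns
  `rp.dtype` for any RowPartition `rp`: the plain int contributes no preferred
  dtype, so the RowPartition's dtype is used.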
1894 """ 1895 if dtype is not None: 1896 return dtype 1897 for x in iterable: 1898 dtype = _find_dtype(x, dtype) 1899 return dtype 1900 1901 1902class _LayerBroadcaster(abc.ABC): 1903 """A broadcaster of a single layer. 1904 1905 Although this class does not literally contain a gather_index, the reference 1906 implementation is defined through a gather_index. Thus, any subclasses should 1907 first define the gather_index property. Other functions can be overridden 1908 for optimization, but it should not change the behavior. 1909 """ 1910 1911 @property 1912 @abc.abstractmethod 1913 def gather_index(self): 1914 """Returns a 1D tensor. 1915 1916 The size of the 1D tensor is equal to the destination size. 1917 1918 The ith element of the result is the index of the source of the ith element. 1919 """ 1920 pass 1921 1922 @property 1923 def dtype(self): 1924 """Returns the dtype of the broadcast.""" 1925 return self.gather_index.dtype 1926 1927 @abc.abstractmethod 1928 def with_dtype(self, dtype): 1929 """Returns an identical _LayerBroadcaster with a different dtype.""" 1930 pass 1931 1932 def __repr__(self): 1933 return str(self.gather_index) 1934 1935 @classmethod 1936 def from_gather_index(cls, gather_index): 1937 """Create a broadcaster from a gather_index.""" 1938 return _GatherLayerBroadcaster(gather_index) 1939 1940 @classmethod 1941 def first_layer(cls, nrows_source, nrows_target): 1942 """Create a broadcaster from a gather_index.""" 1943 gather_index = _first_layer_gather_index(nrows_source, nrows_target) 1944 return _LayerBroadcaster.from_gather_index(gather_index) 1945 1946 @classmethod 1947 def get_singleton_broadcaster(cls, target_size): 1948 """Broadcast from 1 element to target_size elements.""" 1949 return _LayerBroadcaster.from_gather_index( 1950 array_ops.zeros(target_size, dtype=target_size.dtype)) 1951 1952 @abc.abstractmethod 1953 def with_dependencies(self, checks): 1954 """Add dependencies to a _LayerBroadcaster. 1955 1956 Args: 1957 checks: a list of ops that need to be run before any tensors from the 1958 Broadcaster are used. 1959 1960 Returns: 1961 a copy of this _LayerBroadcaster with dependencies added. 1962 """ 1963 pass 1964 1965 @classmethod 1966 def get_identity_broadcaster(cls, nvals, dtype=None): 1967 """Create an identity broadcaster. 1968 1969 TODO(martinz): an identity broadcaster can be far more efficient than a 1970 generic broadcaster. Add an optimized implementation. 1971 Args: 1972 nvals: the number of values for the broadcaster. 1973 dtype: the dtype of the broadcaster, or None to use the dtype of nvals. 1974 Returns: 1975 an identity broadcaster from [0....nvals-1] to [0...nvals-1] 1976 """ 1977 return _GatherLayerBroadcaster(math_ops.range(nvals, dtype=dtype)) 1978 1979 def broadcast_tensor(self, tensor): 1980 """Broadcast from a dense tensor. 1981 1982 It is assumed that the first axis of the dense tensor is indexed by the 1983 source shape, and at the end, the first axis of the dense tensor is 1984 indexed by the destination shape. 1985 1986 Args: 1987 tensor: a dense tensor. 1988 1989 Returns: 1990 A dense tensor. 1991 """ 1992 return array_ops.gather(tensor, self.gather_index) 1993 1994 def dest_nrows(self): 1995 """Return the number of rows in the resulting gather, or None if tiling.""" 1996 return math_ops.cast( 1997 array_ops.shape(self.gather_index)[0], dtype=self.dtype) 1998 1999 def broadcast_row_partition(self, rp): 2000 """Return a new shape where the rows are broadcasted. 
2001 2002 *--self--->* 2003 | | 2004 rp result 2005 | | 2006 V V 2007 *--------->* 2008 2009 This is equivalent to: 2010 return RowPartition.from_row_lengths(self.broadcast(rp.row_lengths())) 2011 2012 However, if the shape has uniform row length, then that property is 2013 maintained. 2014 2015 Args: 2016 rp: a row partition. 2017 2018 Returns: 2019 a RowPartition representing a broadcast version of this row partition. 2020 """ 2021 if not rp.is_uniform(): 2022 return RowPartition.from_row_lengths( 2023 self.broadcast_tensor(rp.row_lengths())) 2024 else: 2025 return RowPartition.from_uniform_row_length( 2026 rp.uniform_row_length(), 2027 nvals=rp.uniform_row_length() * self.dest_nrows(), 2028 nrows=self.dest_nrows()) 2029 2030 def next_layer(self, original_rp, broadcast_rp): 2031 r"""Create the next layer gather_index whether or not a broadcast happens. 2032 2033 *---------self------->* 2034 | | 2035 original_rp broadcast_rp 2036 | | 2037 \|/ \|/ 2038 *--next_broadcaster-->* 2039 Args: 2040 original_rp: the original row partition. 2041 broadcast_rp: the target row partition. 2042 2043 Returns: 2044 the gather_index for next_broadcaster. 2045 2046 """ 2047 gather_index = _next_layer_gather_index(self, original_rp, broadcast_rp) 2048 return _LayerBroadcaster.from_gather_index(gather_index) 2049 2050 2051class _GatherLayerBroadcaster(_LayerBroadcaster): 2052 """Implements _LayerBroadcaster with an explicit gather_index. 2053 2054 For example, suppose that the source shape is: 2055 [*],[*,*] 2056 And the target shape is: 2057 [*],[*,*],[*],[*,*] 2058 Then, this can be represented with a map: 2059 [0,1,2,0,1,2] 2060 2061 """ 2062 2063 def __init__(self, gather_index): 2064 gather_index = ops.convert_to_tensor(gather_index) 2065 if (gather_index.dtype != dtypes.int64 and 2066 gather_index.dtype != dtypes.int32): 2067 raise ValueError("gather_index must be int64 or int32") 2068 self._gather_index = gather_index 2069 2070 @property 2071 def gather_index(self): 2072 return self._gather_index 2073 2074 def with_dtype(self, dtype): 2075 return _GatherLayerBroadcaster(math_ops.cast(self._gather_index, dtype)) 2076 2077 def with_dependencies(self, checks): 2078 new_gather_index = control_flow_ops.with_dependencies( 2079 checks, self._gather_index) 2080 return _GatherLayerBroadcaster(new_gather_index) 2081 2082 2083class _Broadcaster: 2084 """A _Broadcaster represents a transformation from one shape to another. 2085 2086 It provides a transform for each axis of the source shape to the 2087 corresponding axis of the destination shape. 2088 2089 """ 2090 2091 def __init__(self, 2092 source_shape, 2093 target_shape, 2094 layer_broadcasters, 2095 dtype=None): 2096 """Create a broadcaster. 2097 2098 Do not call directly. 2099 The source_shape, target_shape, and layer_broadcasters are converted 2100 to have the same dtype. 2101 2102 Note: source_shape.rank and target_shape.rank must be known. 2103 Args: 2104 source_shape: the source DynamicRaggedShape 2105 target_shape: the target DynamicRaggedShape 2106 layer_broadcasters: List[_LayerBroadcaster] of length source_shape.rank. 2107 dtype: the preferred dtype of the broadcaster. 2108 2109 Raises: 2110 TypeError: if the input types don't match. 
2111 """ 2112 if not isinstance(source_shape, DynamicRaggedShape): 2113 raise TypeError("source_shape is not a DynamicRaggedShape") 2114 if not isinstance(target_shape, DynamicRaggedShape): 2115 raise TypeError("target_shape is not a DynamicRaggedShape") 2116 if not isinstance(layer_broadcasters, list): 2117 raise TypeError("layer_broadcasters not a list: " + 2118 str(layer_broadcasters)) 2119 for bc in layer_broadcasters: 2120 if not isinstance(bc, _LayerBroadcaster): 2121 raise TypeError("Not a LayerBroadcaster: " + str(bc)) 2122 2123 dtype = _find_dtype(source_shape, dtype) 2124 dtype = _find_dtype(target_shape, dtype) 2125 dtype = _find_dtype_iterable(layer_broadcasters, dtype) 2126 dtype = _find_dtype(dtypes.int64, dtype) 2127 self._source_shape = source_shape.with_dtype(dtype) 2128 self._target_shape = target_shape.with_dtype(dtype) 2129 self._layer_broadcasters = [x.with_dtype(dtype) for x in layer_broadcasters] 2130 2131 def __repr__(self): 2132 return ("{src_shape:" + str(self._source_shape) + ", target_shape:" + 2133 str(self._target_shape) + " layer_broadcasters: " + 2134 str(self._layer_broadcasters) + "}") 2135 2136 def with_dtype(self, dtype): 2137 """Return a copy of this Broadcaster with a different dtype.""" 2138 return _Broadcaster(self._source_shape, self._target_shape, 2139 self._layer_broadcasters, dtype) 2140 2141 @property 2142 def source_shape(self): 2143 return self._source_shape 2144 2145 @property 2146 def target_shape(self): 2147 return self._target_shape 2148 2149 @property 2150 def dtype(self): 2151 return self._source_shape.dtype 2152 2153 def _target_inner_shape_int32(self): 2154 new_inner_shape = self.target_shape.inner_shape 2155 if new_inner_shape.dtype == dtypes.int64: 2156 new_inner_shape = math_ops.cast(new_inner_shape, dtype=dtypes.int32) 2157 return new_inner_shape 2158 2159 # pylint:disable=protected-access 2160 def broadcast_flat_values(self, rt, inner_dimensions=True): 2161 """flat_values of a ragged tensor broadcast to target_shape. 2162 2163 If inner_dimensions==True, then the result is a dense tensor with shape 2164 target_shape.inner_shape, the flat values of the broadcasted shape. 2165 2166 If you add target_shape.row_partitions, you will get the full broadcasted 2167 shape. 2168 2169 If inner_dimensions==False, the result is a dense tensor that satsifies 2170 certain properties: 2171 1. broadcast_to(result, target_shape.inner_shape) will give the result 2172 if inner_dimensions==True. 2173 2. Either (a) (result.rank < target_shape.inner_rank) 2174 or (b) (result.shape[0] == target_shape.inner_shape[0]). 2175 3. result.rank = min(target_shape.inner_rank, rt.rank) 2176 4. For i < target_shape.inner_rank - 1, and i < rt.rank, 2177 and if rt.shape[-i]!=1, then result.shape[-i]=target_shape[-i]. 2178 Args: 2179 rt: a ragged or dense tensor. 2180 inner_dimensions: if true, broadcast the inner dimensions as well. 2181 2182 Returns: 2183 a dense tensor 2184 """ 2185 if ragged_tensor.is_ragged(rt): 2186 rt = rt.flat_values 2187 # If rt was a regular tensor, it is its own flat_values. 2188 if self.target_shape.rank == 0: 2189 return rt 2190 inner_rank = self.target_shape.inner_rank 2191 if inner_rank > self._source_shape.rank: 2192 # The dense rank is larger than the whole shape. So, we make the shape 2193 # dense. 2194 if self.source_shape.num_row_partitions > 0: 2195 rt = array_ops.reshape( 2196 rt, self.source_shape._alt_inner_shape(self.source_shape.rank)) 2197 # rt.rank == self._source_shape.rank < inner_rank 2198 # Here, property 2a holds. 
2199 if inner_dimensions: 2200 return array_ops.broadcast_to(rt, self._target_inner_shape_int32()) 2201 return rt 2202 else: 2203 if self._source_shape.inner_rank != inner_rank: 2204 rt = array_ops.reshape(rt, 2205 self._source_shape._alt_inner_shape(inner_rank)) # pylint:disable=protected-access 2206 # After the reshape, rt is flat_values with inner_rank. 2207 flat_broadcaster = self._layer_broadcasters[-inner_rank] 2208 rt = flat_broadcaster.broadcast_tensor(rt) 2209 # Here, property 2b holds. 2210 if inner_dimensions: 2211 rt = array_ops.broadcast_to(rt, self._target_inner_shape_int32()) 2212 return rt 2213 2214 def broadcast(self, rt): 2215 """Broadcast a tensor of source_shape to target_shape.""" 2216 flat_values = self.broadcast_flat_values(rt) 2217 return self.target_shape._add_row_partitions(flat_values) # pylint:disable=protected-access 2218 2219 2220def _get_layer_broadcasters_from_rps(zero_broadcaster, source_rps, target_rps): 2221 """Get LayerBroadcasters from RowPartitions. 2222 2223 *--zero_broadcaster->* 2224 | | 2225 source_rps[0] target_rps[0] 2226 | | 2227 V V 2228 *---result[1]------->* 2229 | | 2230 source_rps[1] target_rps[1] 2231 | | 2232 V V 2233 *---result[2]------->* 2234 . 2235 . 2236 . 2237 *---result[k-1]----->* 2238 | | 2239 source_rps[k] target_rps[k] 2240 | | 2241 V V 2242 *---result[k]------->* 2243 2244 Note: result[0] = zero_broadcaster 2245 2246 Args: 2247 zero_broadcaster: a broadcaster between the source and target row 2248 partitions' rows, and equal to result[0]. 2249 source_rps: source row partitions. 2250 target_rps: target row partitions (same length as source_rps). 2251 2252 Returns: 2253 result: a list of LayerBroadcasters. 2254 """ 2255 if not isinstance(zero_broadcaster, _LayerBroadcaster): 2256 raise TypeError("Not a _LayerBroadcaster: " + str(zero_broadcaster)) 2257 assert len(source_rps) == len(target_rps) 2258 if not source_rps: 2259 return [zero_broadcaster] 2260 next_broadcaster = zero_broadcaster.next_layer(source_rps[0], target_rps[0]) 2261 tail_broadcasters = _get_layer_broadcasters_from_rps(next_broadcaster, 2262 source_rps[1:], 2263 target_rps[1:]) 2264 return [zero_broadcaster] + tail_broadcasters 2265 2266 2267def _get_broadcaster(source_shape, target_shape): 2268 """Get a _Broadcaster from source_shape to target_shape.""" 2269 if source_shape.dtype != target_shape.dtype: 2270 raise ValueError("The source and target row_split dtypes should be equal") 2271 2272 if (source_shape.rank is None or target_shape.rank is None): 2273 raise ValueError("Rank of source and target must be statically known") 2274 elif source_shape.rank > target_shape.rank: 2275 raise ValueError("Cannot broadcast to a shape with smaller rank") 2276 elif source_shape.rank == 0: 2277 return _Broadcaster(source_shape, target_shape, []) 2278 elif target_shape.rank == 1: 2279 assert source_shape.rank == 1 2280 layer = _LayerBroadcaster.first_layer(source_shape.inner_shape[0], 2281 target_shape.inner_shape[0]) 2282 return _Broadcaster(source_shape, target_shape, [layer]) 2283 2284 assert source_shape.rank <= target_shape.rank 2285 assert target_shape.rank >= 2 2286 assert source_shape.rank >= 1 2287 2288 source_rps = source_shape._as_row_partitions() # pylint: disable=protected-access 2289 2290 target_rps = target_shape._as_row_partitions() # pylint: disable=protected-access 2291 2292 assert len(target_rps) >= 1 2293 assert len(source_rps) <= len(target_rps) 2294 source_nrows = source_shape[0] 2295 if len(source_rps) < len(target_rps): 2296 # Note: this includes the 
case where len(source_rps)==0. 2297 # Here we begin at -1, one dimension before source_rps[0]. 2298 # neg_one_source_rp | neg_one_target_rp=target_rps[-(len(source_rps)+1)] 2299 # source_rps[0] | target_rps[-len(source_rps)] 2300 # source_rps[1] | target_rps[1-len(source_rps)] 2301 # ... | ... 2302 # source_rps[-1] | target_rps[-1] 2303 neg_one_source_rp = RowPartition.from_uniform_row_length( 2304 uniform_row_length=source_nrows, nrows=1, nvals=source_nrows) 2305 neg_one_target_rp = target_rps[-(len(source_rps) + 1)] 2306 neg_one_broadcaster = _LayerBroadcaster.get_singleton_broadcaster( 2307 neg_one_target_rp.nrows()) 2308 zeroth_broadcaster = neg_one_broadcaster.next_layer(neg_one_source_rp, 2309 neg_one_target_rp) 2310 target_rps_tail = target_rps[-len(source_rps):] if len( 2311 source_rps) >= 1 else [] 2312 2313 layers = _get_layer_broadcasters_from_rps(zeroth_broadcaster, source_rps, 2314 target_rps_tail) 2315 return _Broadcaster(source_shape, target_shape, layers) 2316 else: 2317 assert len(target_rps) == len(source_rps) 2318 zeroth_broadcaster = _LayerBroadcaster.first_layer(source_rps[0].nrows(), 2319 target_rps[0].nrows()) 2320 layers = _get_layer_broadcasters_from_rps(zeroth_broadcaster, source_rps, 2321 target_rps) 2322 2323 return _Broadcaster(source_shape, target_shape, layers) 2324 2325 2326def _get_identity_broadcaster(shape): 2327 """Gets a Broadcaster for two identical shapes.""" 2328 if shape.rank is None: 2329 raise ValueError("Shape must have a defined rank") 2330 layers = [ 2331 _LayerBroadcaster.get_identity_broadcaster( 2332 shape._num_slices_in_dimension(i)) for i in range(shape.rank) # pylint: disable=protected-access 2333 ] 2334 return _Broadcaster(shape, shape, layers) 2335 2336 2337def _broadcast_dynamic_shape_one_layer(a, b): 2338 """Broadcast two vectors, given their shapes. 2339 2340 Args: 2341 a: the number of rows in a. 2342 b: the number of rows in b. 2343 2344 Returns: 2345 (layer_a, layer_b, target_shape) 2346 layer_a is a _LayerBroadcaster from a to the target_shape. 2347 layer_b is a _LayerBroadcaster from b to the target_shape. 2348 target_shape is the target_shape 2349 2350 Raises: 2351 InvalidArgumentError if the shapes are not consistent. 
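
  For example (illustrative), with a = [1] and b = [3], broadcasting proceeds
  from `a`: `layer_a` has gather_index [0, 0, 0], `layer_b` has gather_index
  [0, 1, 2], and `target_shape` is `b` (that is, [3]).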
  """
  a_0 = a[0]
  b_0 = b[0]

  def broadcast_from_a():
    # Assumes a_0 == 1
    a_layer = array_ops.zeros(b_0, dtype=b_0.dtype)
    b_layer = math_ops.range(b_0)
    target = b
    return [a_layer, b_layer, target]

  a_static = tensor_util.constant_value(a)
  if a_static is not None and a_static[0] == 1:
    [a_gi, b_gi, target] = broadcast_from_a()
    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
    return [a_layer, b_layer, target]

  def broadcast_from_b():
    # Assumes b_0 == 1
    a_layer = math_ops.range(a_0)
    b_layer = array_ops.zeros(a_0, dtype=a_0.dtype)
    target = a
    return [a_layer, b_layer, target]

  b_static = tensor_util.constant_value(b)
  if b_static is not None and b_static[0] == 1:
    [a_gi, b_gi, target] = broadcast_from_b()
    a_layer = _LayerBroadcaster.from_gather_index(a_gi)
    b_layer = _LayerBroadcaster.from_gather_index(b_gi)
    return [a_layer, b_layer, target]

  def broadcast_noop():
    # Assumes a_0 == b_0
    a_layer = math_ops.range(a_0)
    b_layer = math_ops.range(b_0)
    target = b
    return [a_layer, b_layer, target]

  can_broadcast_from_a = math_ops.equal(a_0, 1)
  can_broadcast_from_b = math_ops.equal(b_0, 1)

  def broadcast_not_from_a():
    return control_flow_ops.cond(
        can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop)

  nrows_equal = math_ops.equal(a_0, b_0)
  can_broadcast = math_ops.logical_or(
      can_broadcast_from_a,
      math_ops.logical_or(can_broadcast_from_b, nrows_equal))

  check_can_broadcast = check_ops.assert_equal(
      can_broadcast, True, message="Cannot broadcast")

  results = control_flow_ops.cond(
      can_broadcast_from_a,
      true_fn=broadcast_from_a,
      false_fn=broadcast_not_from_a)

  results = [
      control_flow_ops.with_dependencies([check_can_broadcast], x)
      for x in results
  ]
  [a_gi, b_gi, target] = results
  a_layer = _LayerBroadcaster.from_gather_index(a_gi)
  b_layer = _LayerBroadcaster.from_gather_index(b_gi)
  return [a_layer, b_layer, target]


def _broadcast_dynamic_shape_first_layer(a_0, b_0):
  """Broadcast the first layer of two dynamic shapes given the dimensions.

  Args:
    a_0: the number of rows in a.
    b_0: the number of rows in b.

  Returns:
    (layer_a, layer_b)
    layer_a is a _LayerBroadcaster from a to the target.
    layer_b is a _LayerBroadcaster from b to the target.
2433 """ 2434 def broadcast_from_a(): 2435 # Assumes a_0 == 1 2436 a_layer = array_ops.zeros(b_0, dtype=b_0.dtype) 2437 b_layer = math_ops.range(b_0) 2438 return [a_layer, b_layer] 2439 2440 static_a_0 = tensor_util.constant_value(a_0) 2441 static_b_0 = tensor_util.constant_value(b_0) 2442 if static_a_0 is not None: 2443 if static_a_0 == static_b_0: 2444 id_broadcaster = _LayerBroadcaster.get_identity_broadcaster( 2445 static_a_0, dtype=a_0.dtype) 2446 return [id_broadcaster, id_broadcaster] 2447 elif static_a_0 == 1: 2448 return [ 2449 _LayerBroadcaster.get_singleton_broadcaster(b_0), 2450 _LayerBroadcaster.get_identity_broadcaster(b_0) 2451 ] 2452 2453 if static_b_0 == 1: 2454 return [ 2455 _LayerBroadcaster.get_identity_broadcaster(a_0), 2456 _LayerBroadcaster.get_singleton_broadcaster(a_0) 2457 ] 2458 2459 def broadcast_from_b(): 2460 # Assumes b_0 == 1 2461 a_layer = math_ops.range(a_0) 2462 b_layer = array_ops.zeros(a_0, dtype=a_0.dtype) 2463 return [a_layer, b_layer] 2464 2465 def broadcast_noop(): 2466 # Assumes a_0 == b_0 2467 a_layer = math_ops.range(a_0) 2468 b_layer = math_ops.range(b_0) 2469 return [a_layer, b_layer] 2470 2471 can_broadcast_from_a = math_ops.equal(a_0, constant_op.constant(1, a_0.dtype)) 2472 can_broadcast_from_b = math_ops.equal(b_0, constant_op.constant(1, b_0.dtype)) 2473 2474 def broadcast_not_from_a(): 2475 return control_flow_ops.cond( 2476 can_broadcast_from_b, true_fn=broadcast_from_b, false_fn=broadcast_noop) 2477 2478 # Ideally, this would only block control flow on broadcast_noop, but 2479 # the control flow doesn't seem to work. 2480 can_broadcast = math_ops.logical_or( 2481 math_ops.logical_or(can_broadcast_from_a, can_broadcast_from_b), 2482 math_ops.equal(a_0, b_0)) 2483 2484 result = control_flow_ops.cond( 2485 can_broadcast_from_a, 2486 true_fn=broadcast_from_a, 2487 false_fn=broadcast_not_from_a) 2488 2489 return [ 2490 _LayerBroadcaster.from_gather_index( 2491 control_flow_ops.with_dependencies( 2492 [check_ops.assert_equal(can_broadcast, True)], x)) for x in result 2493 ] 2494 2495 2496def _broadcast_half( 2497 ac_0: _LayerBroadcaster, 2498 a_1: RowPartition) -> Tuple[_LayerBroadcaster, RowPartition]: 2499 """Does a NOOP broadcast of a_1. 2500 2501 *-ac_0-->* 2502 | | 2503 a_1 c_1 2504 | | 2505 V V 2506 *-ac_1-->* 2507 2508 Note that by definition this cannot fail: there is always a well-defined 2509 NOOP broadcast. This is usually intended as half of broadcasting two shapes 2510 together. 2511 Args: 2512 ac_0: previous LayerBroadcaster 2513 a_1: previous RowPartition 2514 2515 Returns: 2516 [ac_1, c_1] where ac_1 is the next LayerBroadcaster, and c_1 is the 2517 broadcast RowPartition 2518 """ 2519 c_1 = ac_0.broadcast_row_partition(a_1) 2520 old_value_rowids = array_ops.gather(ac_0.gather_index, c_1.value_rowids()) 2521 old_row_starts = array_ops.gather(a_1.row_splits(), old_value_rowids) 2522 gather_index = old_row_starts + c_1.offsets_in_rows() 2523 return [_LayerBroadcaster.from_gather_index(gather_index), c_1] 2524 2525 2526def _broadcast_dynamic_shape_next_layer_half_ragged( 2527 ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition, 2528 b_1: RowPartition 2529) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]: 2530 r"""Broadcast target and next layer broadcaster of two dynamic shapes. 2531 2532 a_1 is uniform, and b_1 is ragged. 
2533 *--ac_0-->*<--bc_0--* 2534 | | | 2535 a_1 c_1 b_1 2536 | | | 2537 V V V 2538 *--ac_1-->*<--bc_1--* 2539 2540 Args: 2541 ac_0: _LayerBroadcaster from a to c in the previous layer. 2542 bc_0: _LayerBroadcaster from b to c in the previous layer. 2543 a_1: a uniform RowPartition for the next layer of a. 2544 b_1: a ragged RowPartition for the next layer of b. 2545 2546 Returns: 2547 (c_1, ac_1, bc_1) 2548 c_1: a RowPartition for the next layer of the dynamic shape. 2549 ac_1: _LayerBroadcaster from a to c in the next layer. 2550 bc_1: _LayerBroadcaster from b to c in the next layer. 2551 """ 2552 if not isinstance(ac_0, _LayerBroadcaster): 2553 raise TypeError("ac_0 should be a _LayerBroadcaster") 2554 if not isinstance(bc_0, _LayerBroadcaster): 2555 raise TypeError("bc_0 should be a _LayerBroadcaster") 2556 if not isinstance(a_1, RowPartition): 2557 raise TypeError("a_1 should be a RowPartition") 2558 if not isinstance(b_1, RowPartition): 2559 raise TypeError("b_1 should be a RowPartition") 2560 2561 assert a_1.is_uniform() 2562 assert not b_1.is_uniform() 2563 2564 static_a_1 = tensor_util.constant_value(a_1.uniform_row_length()) 2565 if static_a_1 == 1: 2566 [bc_1, c_1b] = _broadcast_half(bc_0, b_1) 2567 ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids()) 2568 c_1 = RowPartition.from_row_splits(c_1b.row_splits()) 2569 ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index) 2570 bc_1 = _LayerBroadcaster.from_gather_index(bc_1.gather_index) 2571 return [c_1, ac_1, bc_1] 2572 2573 def broadcast_noop(): 2574 # The sides must be "equal". 2575 [ac_1, c_1a] = _broadcast_half(ac_0, a_1) 2576 [bc_1, c_1b] = _broadcast_half(bc_0, b_1) 2577 checks = [check_ops.assert_equal(c_1a.row_splits(), c_1b.row_splits())] 2578 return [ 2579 control_flow_ops.with_dependencies(checks, x) 2580 for x in [a_1.row_splits(), ac_1.gather_index, bc_1.gather_index] 2581 ] 2582 2583 def broadcast_a(): 2584 [bc_1, c_1b] = _broadcast_half(bc_0, b_1) 2585 ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids()) 2586 return [ 2587 c_1b.row_splits(), 2588 ac_1_gather_index, 2589 bc_1.gather_index, 2590 ] 2591 2592 can_broadcast_a = math_ops.equal(a_1.uniform_row_length(), 1) 2593 2594 [c_1_row_splits, ac_1_gather_index, 2595 bc_1_gather_index] = control_flow_ops.cond( 2596 can_broadcast_a, true_fn=broadcast_a, false_fn=broadcast_noop) 2597 2598 c_1 = RowPartition.from_row_splits(c_1_row_splits) 2599 ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index) 2600 bc_1 = _LayerBroadcaster.from_gather_index(bc_1_gather_index) 2601 return [c_1, ac_1, bc_1] 2602 2603 2604def _broadcast_dynamic_shape_next_layer_both_uniform( 2605 ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition, 2606 b_1: RowPartition 2607) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]: 2608 r"""Broadcast target and next layer broadcaster of two uniform dynamic shapes. 2609 2610 *--ac_0-->*<--bc_0--* 2611 | | | 2612 a_1 c_1 b_1 2613 | | | 2614 V V V 2615 *--ac_1-->*<--bc_1--* 2616 2617 Args: 2618 ac_0: _LayerBroadcaster from a to c in the previous layer. 2619 bc_0: _LayerBroadcaster from b to c in the previous layer. 2620 a_1: a RowPartition for the next layer of a. 2621 b_1: a RowPartition for the next layer of b. 2622 2623 Returns: 2624 (c_1, ac_1, bc_1) 2625 c_1: a RowPartition for the next layer of the dynamic shape. 2626 ac_1: _LayerBroadcaster from a to c in the next layer. 2627 bc_1: _LayerBroadcaster from b to c in the next layer. 
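
  For example (illustrative): if a_1 has uniform_row_length 1 and b_1 has
  uniform_row_length 4, then c_1 has uniform_row_length 4 and a's values are
  repeated along that axis; if both have uniform_row_length 3, c_1 also has
  uniform_row_length 3; unequal lengths where neither is 1 fail the runtime
  broadcast assertion.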
  """
  if not isinstance(ac_0, _LayerBroadcaster):
    raise TypeError("ac_0 should be a _LayerBroadcaster")
  if not isinstance(bc_0, _LayerBroadcaster):
    raise TypeError("bc_0 should be a _LayerBroadcaster")
  if not isinstance(a_1, RowPartition):
    raise TypeError("a_1 should be a RowPartition")
  if not isinstance(b_1, RowPartition):
    raise TypeError("b_1 should be a RowPartition")
  assert a_1.is_uniform()
  assert b_1.is_uniform()

  static_a_1 = tensor_util.constant_value(a_1.uniform_row_length())
  static_b_1 = tensor_util.constant_value(b_1.uniform_row_length())

  if static_a_1 is not None:
    if static_a_1 == static_b_1:
      # Here, this dimension is the same, but we may have to broadcast previous
      # dimensions.
      [ac_1, _] = _broadcast_half(ac_0, a_1)
      [bc_1, _] = _broadcast_half(bc_0, b_1)
      c_1 = RowPartition.from_uniform_row_length(
          static_a_1,
          nrows=ac_0.dest_nrows())
      return [c_1, ac_1, bc_1]
    elif static_a_1 == 1:
      [bc_1, c_1b] = _broadcast_half(bc_0, b_1)
      ac_1 = _LayerBroadcaster.from_gather_index(
          array_ops.gather(ac_0.gather_index, c_1b.value_rowids()))
      c_1 = RowPartition.from_uniform_row_length(
          b_1.uniform_row_length(),
          nrows=bc_0.dest_nrows())
      return [c_1, ac_1, bc_1]

  if static_b_1 == 1:
    [ac_1, c_1a] = _broadcast_half(ac_0, a_1)
    bc_1 = _LayerBroadcaster.from_gather_index(
        array_ops.gather(bc_0.gather_index, c_1a.value_rowids()))
    c_1 = RowPartition.from_uniform_row_length(
        a_1.uniform_row_length(),
        nrows=ac_0.dest_nrows())
    return [c_1, ac_1, bc_1]

  def broadcast_noop():
    # Assumes a_1.uniform_row_length() == b_1.uniform_row_length()
    # Both sides broadcast to a single shape.
2674 [ac_1, _] = _broadcast_half(ac_0, a_1) 2675 [bc_1, _] = _broadcast_half(bc_0, b_1) 2676 return [a_1.uniform_row_length(), ac_1.gather_index, bc_1.gather_index] 2677 2678 def broadcast_a(): 2679 [bc_1, c_1b] = _broadcast_half(bc_0, b_1) 2680 ac_1_gather_index = array_ops.gather(ac_0.gather_index, c_1b.value_rowids()) 2681 return [ 2682 b_1.uniform_row_length(), 2683 ac_1_gather_index, 2684 bc_1.gather_index, 2685 ] 2686 2687 def broadcast_b(): 2688 [ac_1, c_1a] = _broadcast_half(ac_0, a_1) 2689 bc_1_gather_index = array_ops.gather(bc_0.gather_index, c_1a.value_rowids()) 2690 return [a_1.uniform_row_length(), ac_1.gather_index, bc_1_gather_index] 2691 2692 can_broadcast_b = math_ops.equal(b_1.uniform_row_length(), 1) 2693 2694 def no_broadcast_a(): 2695 return control_flow_ops.cond( 2696 can_broadcast_b, true_fn=broadcast_b, false_fn=broadcast_noop) 2697 2698 can_broadcast_a = math_ops.equal(a_1.uniform_row_length(), 1) 2699 2700 broadcast_asserts = [ 2701 check_ops.assert_equal( 2702 math_ops.logical_or( 2703 math_ops.logical_or(can_broadcast_a, can_broadcast_b), 2704 math_ops.equal(a_1.uniform_row_length(), 2705 b_1.uniform_row_length())), True) 2706 ] 2707 2708 result = control_flow_ops.cond( 2709 can_broadcast_a, true_fn=broadcast_a, false_fn=no_broadcast_a) 2710 2711 [c_1_uniform_row_length, ac_1_gather_index, bc_1_gather_index] = [ 2712 control_flow_ops.with_dependencies(broadcast_asserts, x) for x in result 2713 ] 2714 2715 c_1 = RowPartition.from_uniform_row_length( 2716 c_1_uniform_row_length, 2717 nvals=c_1_uniform_row_length * ac_0.dest_nrows(), 2718 nrows=ac_0.dest_nrows()) 2719 ac_1 = _LayerBroadcaster.from_gather_index(ac_1_gather_index) 2720 bc_1 = _LayerBroadcaster.from_gather_index(bc_1_gather_index) 2721 return [c_1, ac_1, bc_1] 2722 2723 2724def _broadcast_dynamic_shape_next_layer( 2725 ac_0: _LayerBroadcaster, bc_0: _LayerBroadcaster, a_1: RowPartition, 2726 b_1: RowPartition 2727) -> Tuple[RowPartition, _LayerBroadcaster, _LayerBroadcaster]: 2728 r"""Broadcast target and next layer broadcaster of two dynamic shapes. 2729 2730 *--ac_0-->*<--bc_0--* 2731 | | | 2732 a_1 c_1 b_1 2733 | | | 2734 V V V 2735 *--ac_1-->*<--bc_1--* 2736 2737 Args: 2738 ac_0: _LayerBroadcaster from a to c in the previous layer. 2739 bc_0: _LayerBroadcaster from b to c in the previous layer. 2740 a_1: a RowPartition for the next layer of a. 2741 b_1: a RowPartition for the next layer of b. 2742 2743 Returns: 2744 (c_1, ac_1, bc_1) 2745 c_1: a RowPartition for the next layer of the dynamic shape. 2746 ac_1: _LayerBroadcaster from a to c in the next layer. 2747 bc_1: _LayerBroadcaster from b to c in the next layer. 
2748 """ 2749 if not isinstance(ac_0, _LayerBroadcaster): 2750 raise TypeError("ac_0 should be a _LayerBroadcaster") 2751 if not isinstance(bc_0, _LayerBroadcaster): 2752 raise TypeError("bc_0 should be a _LayerBroadcaster") 2753 if not isinstance(a_1, RowPartition): 2754 raise TypeError("a_1 should be a RowPartition") 2755 if not isinstance(b_1, RowPartition): 2756 raise TypeError("b_1 should be a RowPartition") 2757 2758 if a_1.is_uniform(): 2759 if b_1.is_uniform(): 2760 return _broadcast_dynamic_shape_next_layer_both_uniform( 2761 ac_0, bc_0, a_1, b_1) 2762 else: 2763 return _broadcast_dynamic_shape_next_layer_half_ragged( 2764 ac_0, bc_0, a_1, b_1) 2765 else: 2766 if b_1.is_uniform(): 2767 [c_1, bc_1, ac_1] = _broadcast_dynamic_shape_next_layer_half_ragged( # pylint: disable=arguments-out-of-order 2768 bc_0, ac_0, b_1, a_1) 2769 return (c_1, ac_1, bc_1) 2770 else: 2771 # If neither shape is uniform, we cannot broadcast the dimension. 2772 [ac_1, c_1a] = _broadcast_half(ac_0, a_1) 2773 [bc_1, c_1b] = _broadcast_half(bc_0, b_1) 2774 check_valid = [ 2775 check_ops.assert_equal(c_1a.row_splits(), c_1b.row_splits()) 2776 ] 2777 return (c_1a._with_dependencies(check_valid), # pylint: disable=protected-access 2778 ac_1.with_dependencies(check_valid), 2779 bc_1.with_dependencies(check_valid)) 2780 2781 2782def _broadcast_dynamic_shape_from_rps( 2783 a_zero: _LayerBroadcaster, b_zero: _LayerBroadcaster, 2784 a_rps: Sequence[RowPartition], b_rps: Sequence[RowPartition] 2785) -> Tuple[Sequence[RowPartition], Sequence[_LayerBroadcaster], 2786 Sequence[_LayerBroadcaster]]: 2787 """Create BroadcastLayers from two shapes to a target shape. 2788 2789 2790 *--a_zero->*<-b_zero-* 2791 | | | 2792 a_rps[0] c_rps[0] b_rps[0] 2793 | | | 2794 V V V 2795 *--ac[1]-->*<-bc[1]--* 2796 | | | 2797 a_rps[1] c_rps[0] b_rps[1] 2798 | | | 2799 V V V 2800 *--ac[2]-->*<-bc[2]--* 2801 2802 Note: ac[0]=a_zero, and bc[0]=b_zero. 2803 Args: 2804 a_zero: broadcaster from rows of a_rps[0] to target shape. 2805 b_zero: broadcaster from rows of b_rps[0] to target shape. 2806 a_rps: RowPartitions of first shape. 2807 b_rps: RowPartitions of second shape, equal in length to a_rps. 2808 2809 Returns: 2810 (c_rps, ac, bc) where: 2811 c_rps: RowPartitions of target shape. 2812 ac: layers broadcasting from the first shape. 2813 bc: layers broadcasting from the second shape. 2814 """ 2815 assert len(a_rps) == len(b_rps) 2816 if a_rps: 2817 (c_1, ac_1, 2818 bc_1) = _broadcast_dynamic_shape_next_layer(a_zero, b_zero, a_rps[0], 2819 b_rps[0]) 2820 (c_suffix, a_layers, 2821 b_layers) = _broadcast_dynamic_shape_from_rps(ac_1, bc_1, a_rps[1:], 2822 b_rps[1:]) 2823 2824 return ([c_1] + c_suffix, [ac_1] + a_layers, [bc_1] + b_layers) 2825 else: 2826 return ([], [], []) 2827 2828 2829def _get_broadcast_num_row_partitions(a: DynamicRaggedShape, 2830 b: DynamicRaggedShape): 2831 """Returns broadcast_dynamic_shape(a, b).num_row_partitions.""" 2832 # Assumes rank and num_row_partitions are not None. 
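  # For example (illustrative): if a has rank 2 with 1 row partition and b has
  # rank 4 with 2 row partitions, then a expands to 1 + (4 - 2) = 3 row
  # partitions, b stays at 2, and the broadcast result uses max(3, 2) = 3.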
2833 if (a.num_row_partitions == 0 and b.num_row_partitions == 0): 2834 return 0 2835 expanded_num_row_partitions_a = a.num_row_partitions + max(0, b.rank - a.rank) 2836 expanded_num_row_partitions_b = b.num_row_partitions + max(0, a.rank - b.rank) 2837 2838 if a.num_row_partitions == 0: 2839 return expanded_num_row_partitions_b 2840 2841 if b.num_row_partitions == 0: 2842 return expanded_num_row_partitions_a 2843 2844 return max(expanded_num_row_partitions_a, expanded_num_row_partitions_b) 2845 2846 2847# pylint: disable=protected-access 2848def _broadcast_dynamic_shape_extended_complete( 2849 a: DynamicRaggedShape, b: DynamicRaggedShape, b_rps: Sequence[RowPartition], 2850 c_suffix: Sequence[RowPartition], ac: Sequence[_LayerBroadcaster], 2851 bc_suffix: Sequence[_LayerBroadcaster] 2852) -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]: 2853 """Helper for broadcast_dynamic_shape_extended.""" 2854 c_prefix = b_rps[:-len(c_suffix)] 2855 bc_prefix_length = b.rank - len(bc_suffix) 2856 bc_prefix = [ 2857 _LayerBroadcaster.get_identity_broadcaster(b._num_slices_in_dimension(i)) 2858 for i in range(bc_prefix_length) 2859 ] 2860 c_num_row_partitions = _get_broadcast_num_row_partitions(a, b) 2861 2862 c_raw = DynamicRaggedShape.from_row_partitions(c_prefix + tuple(c_suffix)) 2863 c = c_raw._with_num_row_partitions(c_num_row_partitions) 2864 return (c, _Broadcaster(a, c, ac), _Broadcaster(b, c, bc_prefix + bc_suffix)) 2865 2866 2867def _broadcast_dynamic_shape_extended_helper( 2868 a: DynamicRaggedShape, b: DynamicRaggedShape 2869) -> Tuple[DynamicRaggedShape, _Broadcaster, _Broadcaster]: 2870 """Helper for broadcast_dynamic_shape_extended. 2871 2872 Here, we force: 2873 a.rank <= b.rank 2874 2 <= b.rank 2875 1 <= a.rank 2876 Args: 2877 a: a DynamicRaggedShape 2878 b: a DynamicRaggedShape 2879 2880 Returns: 2881 A triple of a shape and two broadcasters. 2882 """ 2883 assert a.rank <= b.rank 2884 assert 2 <= b.rank 2885 assert 1 <= a.rank 2886 a_rps = a._as_row_partitions() # pylint: disable=protected-access 2887 b_rps = b._as_row_partitions() # pylint: disable=protected-access 2888 2889 if len(a_rps) < len(b_rps): 2890 # Note: this includes the case where len(a_rps)==0. 2891 # Here we begin at -1, one dimension before a_rps[0]. 2892 # neg_one_a_rp | b_rps[-(len(a_rps)+1)] 2893 # a_rps[0] | b_rps[-len(a_rps)] 2894 # a_rps[1] | b_rps[1-len(a_rps)] 2895 # ... | ... 2896 # a_rps[-1] | b_rps[-1] 2897 2898 a_nrows = a[0] 2899 a_nrows_static = tensor_util.constant_value(a_nrows) 2900 if a_nrows_static is not None: 2901 a_nrows = a_nrows_static 2902 2903 neg_one_a_rp = RowPartition.from_uniform_row_length( 2904 uniform_row_length=a_nrows, nrows=1, nvals=a_nrows) 2905 neg_one_b_rp = b_rps[-(len(a_rps) + 1)] 2906 (neg_one_ac, neg_one_bc) = _broadcast_dynamic_shape_first_layer( 2907 constant_op.constant(1, dtype=b_rps[0].dtype), neg_one_b_rp.nrows()) 2908 2909 # The first part of the solution. 
2910 (c_zero, ac_zero, 2911 bc_zero) = _broadcast_dynamic_shape_next_layer(neg_one_ac, neg_one_bc, 2912 neg_one_a_rp, neg_one_b_rp) 2913 b_rps_tail = b_rps[-len(a_rps):] if len(a_rps) >= 1 else [] 2914 2915 (c_suffix, ac_layers, 2916 bc_layers) = _broadcast_dynamic_shape_from_rps(ac_zero, bc_zero, a_rps, 2917 b_rps_tail) 2918 2919 return _broadcast_dynamic_shape_extended_complete( 2920 a=a, 2921 b=b, 2922 b_rps=b_rps, 2923 c_suffix=[c_zero] + c_suffix, 2924 ac=[ac_zero] + ac_layers, 2925 bc_suffix=[neg_one_bc, bc_zero] + bc_layers) 2926 2927 else: 2928 assert len(a_rps) == len(b_rps) 2929 (ac_zero, 2930 bc_zero) = _broadcast_dynamic_shape_first_layer(a_rps[0].nrows(), 2931 b_rps[0].nrows()) 2932 2933 (c_rps, a_layers, 2934 b_layers) = _broadcast_dynamic_shape_from_rps(ac_zero, bc_zero, a_rps, 2935 b_rps) 2936 return _broadcast_dynamic_shape_extended_complete( 2937 a=a, 2938 b=b, 2939 b_rps=b_rps, 2940 c_suffix=c_rps, 2941 ac=[ac_zero] + a_layers, 2942 bc_suffix=[bc_zero] + b_layers) 2943 2944 2945def _fix_start_index(index, rank, num_row_partitions): 2946 """Slice indexes are always silently truncated.""" 2947 if index < 0: 2948 if rank is None: 2949 raise ValueError( 2950 "Rank must be known to use __getitem__ on a negative index.") 2951 index = rank + index 2952 if index < 0: 2953 index = 0 2954 if (num_row_partitions > 0 and index <= num_row_partitions + 1): 2955 # The rank is always >= num_row_partitions + 1 if num_row_partitions > 0. 2956 return index 2957 if index == 0: 2958 return index 2959 if rank is None: 2960 raise ValueError("Rank must be known to use __getitem__ on a large index.") 2961 if index >= rank: 2962 index = rank 2963 return index 2964 2965 2966def _fix_stop_index(index, rank): 2967 """Slice indexes are always silently truncated.""" 2968 if index is None: 2969 if rank is None: 2970 raise ValueError("Rank must be known to use __getitem__ without a stop.") 2971 index = rank 2972 if index < 0: 2973 if rank is None: 2974 raise ValueError( 2975 "Rank must be known to use __getitem__ on a negative index.") 2976 index = rank + index 2977 if index < 0: 2978 index = 0 2979 if rank is not None: 2980 index = min(rank, index) 2981 return index 2982 2983 2984def _first_layer_gather_index(nrows_source, nrows_target): 2985 """Return the first layer gather_index. 2986 2987 Args: 2988 nrows_source: the number of rows in the source. 2989 nrows_target: the number of rows in the target. 2990 2991 Returns: 2992 A tensor, usable as a gather_index for a _LayerBroadcaster. 2993 """ 2994 2995 def gi_broadcast_first(): 2996 return array_ops.zeros(nrows_target, dtype=nrows_target.dtype) 2997 2998 def gi_no_broadcast_first(): 2999 gather_index = math_ops.range(nrows_target, dtype=nrows_target.dtype) 3000 return gather_index 3001 3002 do_broadcast = math_ops.equal(nrows_source, 3003 constant_op.constant(1, nrows_source.dtype)) 3004 nrows_equal = math_ops.equal(nrows_source, nrows_target) 3005 can_broadcast = check_ops.assert_equal( 3006 math_ops.logical_or(do_broadcast, nrows_equal), 3007 True, 3008 message="Cannot broadcast") 3009 3010 gather_index = control_flow_ops.cond( 3011 do_broadcast, true_fn=gi_broadcast_first, false_fn=gi_no_broadcast_first) 3012 3013 return control_flow_ops.with_dependencies([can_broadcast], gather_index) 3014 3015 3016def _next_layer_gather_index(bc, original_rp, broadcast_rp): 3017 r"""Create the next layer gather_index whether or not a broadcast happens. 
3018 3019 *----------bc-------->* 3020 | | 3021 original_rp broadcast_rp 3022 | | 3023 \|/ \|/ 3024 *--next_broadcaster-->* 3025 3026 Args: 3027 bc: the old broadcaster. 3028 original_rp: the original row partition. 3029 broadcast_rp: the target row partition. 3030 3031 Returns: 3032 the gather_index for next_broadcaster. 3033 Raises: 3034 InvalidArgumentError if the shapes are incompatible. 3035 """ 3036 old_value_rowids = array_ops.gather(bc.gather_index, 3037 broadcast_rp.value_rowids()) 3038 3039 def gi_no_broadcast(): 3040 # TODO(martinz): decide if row_splits or row_starts should be used here. 3041 old_row_starts = array_ops.gather(original_rp.row_splits(), 3042 old_value_rowids) 3043 expected_row_lengths = array_ops.gather( 3044 params=original_rp.row_lengths(), indices=bc.gather_index) 3045 actual_row_lengths = broadcast_rp.row_lengths() 3046 check_valid = check_ops.assert_equal( 3047 expected_row_lengths, actual_row_lengths, message="Cannot broadcast") 3048 gather_index = old_row_starts + broadcast_rp.offsets_in_rows() 3049 return control_flow_ops.with_dependencies([check_valid], gather_index) 3050 3051 def gi_broadcast(): 3052 # Several optimizations can occur here. 3053 # old_row_starts == old_value_rowids, because: 3054 # if you are broadcasting, then the source has uniform row length of 1, 3055 # implying original_rp.row_splits == tf.range(orgininal_rp.nvals + 1) 3056 # When broadcasting, there is no need to add offsets to the 3057 # source, because the source has size 1. 3058 # Also, this is always valid, because we enforce source and destination 3059 # have uniform_row_length. 3060 return old_value_rowids 3061 3062 if not original_rp.is_uniform(): 3063 return gi_no_broadcast() 3064 3065 do_broadcast = math_ops.equal(original_rp.uniform_row_length(), 3066 constant_op.constant(1, original_rp.dtype)) 3067 gather_index = control_flow_ops.cond( 3068 do_broadcast, true_fn=gi_broadcast, false_fn=gi_no_broadcast) 3069 3070 return gather_index 3071 3072 3073def _flat_values_shape(rt): 3074 if isinstance(rt, ragged_tensor.RaggedTensor): 3075 return array_ops.shape(rt.flat_values) 3076 return rt.flat_values.shape 3077 3078 3079def _to_row_partitions_and_nvals_from_lengths( 3080 lengths: Sequence[Union[int, Sequence[int]]], 3081 dtype=None) -> Tuple[Sequence[RowPartition], int]: 3082 """Allow ragged and uniform shapes to be specified. 3083 3084 For example, [2, [2,1], 2] represents a shape like: 3085 [[[0, 0], [0, 0]], [[0, 0]]] 3086 3087 Args: 3088 lengths: a list of integers and lists of integers. 3089 dtype: dtype of the shape (tf.int32 or tf.int64) 3090 3091 Returns: 3092 a sequence of RowPartitions, and the number of values of the last partition. 3093 """ 3094 size_so_far = lengths[0] 3095 result = [] 3096 for current_lengths in lengths[1:]: 3097 if isinstance(current_lengths, int): 3098 nrows = size_so_far 3099 nvals = current_lengths * nrows 3100 size_so_far = nvals 3101 result.append( 3102 RowPartition.from_uniform_row_length( 3103 current_lengths, nvals, nrows=nrows, dtype_hint=dtype)) 3104 else: 3105 if size_so_far != len(current_lengths): 3106 raise ValueError("Shape not consistent.") 3107 result.append( 3108 RowPartition.from_row_lengths(current_lengths, dtype_hint=dtype)) 3109 size_so_far = sum(current_lengths) 3110 return (result, size_so_far) 3111 3112 3113def _element_to_string(x): 3114 """element to a string within a list.""" 3115 if x is Ellipsis: 3116 return "..." 
3117 if isinstance(x, str): 3118 return "'" + x + "'" 3119 return str(x) 3120 3121 3122def _list_tail_with_ellipsis(arr): 3123 """Print the tail of a list where the list might have an ellipsis.""" 3124 if not arr: 3125 return "]" 3126 else: 3127 return ", " + _element_to_string(arr[0]) + _list_tail_with_ellipsis(arr[1:]) 3128 3129 3130def _list_with_ellipsis_to_str(arr): 3131 """Print a list that might have ellipsis.""" 3132 if not arr: 3133 return "[]" 3134 return "[" + _element_to_string(arr[0]) + _list_tail_with_ellipsis(arr[1:]) 3135 3136 3137def _is_int_or_tuple_of_ints(x): 3138 if isinstance(x, int): 3139 return True 3140 if not isinstance(x, tuple): 3141 return False 3142 for y in x: 3143 if not isinstance(y, int): 3144 return False 3145 return True 3146 3147 3148def _alt_inner_shape_from_tensor_shape(shape, dtype, new_inner_rank): 3149 """Helper for _alt_inner_shape, used directly in _with_num_row_partitions.""" 3150 if new_inner_rank == 1: 3151 return constant_op.constant([shape.num_elements()], dtype=dtype) 3152 new_inner_rank_tail_length = new_inner_rank - 1 3153 inner_shape_tail = shape[-new_inner_rank_tail_length:].as_list() 3154 first_dim = shape[:-new_inner_rank_tail_length].num_elements() 3155 return constant_op.constant([first_dim] + inner_shape_tail, dtype=dtype) 3156 3157 3158def _safe_floor_div(dividend: tensor_shape.Dimension, 3159 divisor: tensor_shape.Dimension) -> tensor_shape.Dimension: 3160 if tensor_shape.dimension_value(divisor) == 0: 3161 return None 3162 return dividend // divisor 3163 3164 3165# TODO(b/218932570) 3166def _reduce_prod_patch(x): 3167 if x.dtype == dtypes.int64: 3168 return math_ops.cast( 3169 math_ops.reduce_prod(math_ops.cast(x, dtypes.int32)), dtypes.int64) 3170 return math_ops.reduce_prod(x) 3171 3172 3173# Type alias for shape encoded as a DynamicRaggedShape or a Tensor. 3174DenseOrRaggedShape = Union[DynamicRaggedShape, core.TensorLike] 3175 3176 3177def _merge_row_partitions( 3178 row_partitions: Sequence[RowPartition]) -> RowPartition: 3179 # TODO(martinz): handle uniform splits. 3180 # TODO(martinz): consider using value_row_ids if present. 3181 # Note: this probably won't be called with len(row_partitions)==1, so no 3182 # need to optimize. 3183 row_splits = row_partitions[0].row_splits() 3184 for rp in row_partitions[1:]: 3185 row_splits = array_ops.gather(rp.row_splits(), row_splits) 3186 return RowPartition.from_row_splits(row_splits) 3187 3188 3189def _merge_inner_shape( 3190 inner_shape: ops.Tensor, 3191 static_inner_shape: tensor_shape.TensorShape, 3192 outer_axis: int, 3193 inner_axis: int) -> Tuple[ops.Tensor, tensor_shape.TensorShape]: 3194 """Merge the inner shape of a DynamicRaggedShape.""" 3195 prefix = inner_shape[:outer_axis] 3196 suffix = inner_shape[inner_axis + 1:] 3197 3198 internal = inner_shape[outer_axis:inner_axis + 1] 3199 internal_value = [_reduce_prod_patch(internal)] 3200 new_internal = array_ops.concat([prefix, internal_value, suffix], axis=0) 3201 prefix_static = static_inner_shape[:outer_axis] 3202 suffix_static = static_inner_shape[inner_axis+1:] 3203 internal_static = static_inner_shape[outer_axis:inner_axis+1] 3204 internal_value_static = tensor_shape.TensorShape( 3205 [internal_static.num_elements()]) 3206 new_internal_static = prefix_static + internal_value_static + suffix_static 3207 3208 return (new_internal, new_internal_static) 3209 3210 3211def _batch_rp_spec(rp_spec: RowPartitionSpec, 3212 batch_size: Optional[int]) -> RowPartitionSpec: 3213 """Batches a RowPartitionSpec. 
3214 3215 Given a RowPartitionSpec and a batch_size, create a RowPartitionSpec that 3216 will be the spec for the concatenation of batch_size RowPartitions. 3217 3218 A RowPartition can be considered a transformation from a list of a given 3219 length to a list of lists. Assume rp_a is a map from list_a to nlist_a, 3220 And rp_b is a map from list_b to nlist_b. concat(rp_a, rp_b) is a 3221 transform of concat(list_a, list_b) to concat(nlist_a, nlist_b). 3222 3223 If batch_size is None, then have the spec be able to handle an arbitrary 3224 number of RowPartitions. 3225 3226 Args: 3227 rp_spec: a RowPartitionSpec for all the RowPartitions to be concatenated. 3228 batch_size: the number of rp_specs to be concatenated. 3229 Returns: 3230 a batched RowPartitionSpec. 3231 """ 3232 if batch_size is None: 3233 return RowPartitionSpec(uniform_row_length=rp_spec.uniform_row_length, 3234 dtype=rp_spec.dtype) 3235 nrows = None if rp_spec.nrows is None else rp_spec.nrows * batch_size 3236 nvals = None if rp_spec.nvals is None else rp_spec.nvals * batch_size 3237 return RowPartitionSpec( 3238 nrows=nrows, nvals=nvals, uniform_row_length=rp_spec.uniform_row_length, 3239 dtype=rp_spec.dtype) 3240 3241 3242def _batch_rp_spec_head(old_head: RowPartitionSpec, 3243 batch_size: Optional[int]) -> RowPartitionSpec: 3244 """Creates a RowPartitionSpec representing the new dimension created.""" 3245 nvals = None if (old_head.nrows is None or 3246 batch_size is None) else batch_size * old_head.nrows 3247 return RowPartitionSpec( 3248 nrows=batch_size, nvals=nvals, uniform_row_length=old_head.nrows, 3249 dtype=old_head.dtype) 3250 3251 3252def _batch_static_inner_shape( 3253 old_shape: tensor_shape.TensorShape, 3254 batch_size: Optional[int]) -> tensor_shape.TensorShape: 3255 """Returns a copy of old_shape with axis=0 multiplied by batch_size. 3256 3257 Only use if this is the inner_shape of a DynamicRaggedShape.Spec with one 3258 or more row partitions. 3259 3260 Args: 3261 old_shape: the original inner_shape. 3262 batch_size: the batch size. 3263 3264 Returns: 3265 a new shape. 3266 """ 3267 head_dim = tensor_shape.dimension_at_index(old_shape, 0) * batch_size 3268 return head_dim + old_shape[1:] 3269 3270 3271def _batch_tensor_shape(old_shape: tensor_shape.TensorShape, 3272 batch_size: int) -> tensor_shape.TensorShape: 3273 return tensor_shape.TensorShape([batch_size]) + old_shape 3274 3275 3276def _unbatch_static_inner_shape( 3277 old_shape: tensor_shape.TensorShape, 3278 batch_size: Optional[int]) -> tensor_shape.TensorShape: 3279 """Unbatch a static_inner_shape when num_row_partitions > 0.""" 3280 head_dim = tensor_shape.dimension_at_index(old_shape, 0) // batch_size 3281 return head_dim + old_shape[1:] 3282 3283 3284# Copied from ragged_array_ops.py 3285def ones(shape: DynamicRaggedShape, 3286 dtype=dtypes.float32, 3287 name: Optional[str] = None) -> ragged_tensor.RaggedOrDense: 3288 """Returns ones shaped like x.""" 3289 flat_values = array_ops.ones(shape.inner_shape, dtype=dtype, name=name) 3290 return ragged_tensor.RaggedTensor._from_nested_row_partitions( # pylint: disable=protected-access 3291 flat_values, shape.row_partitions) 3292
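

# An illustrative usage sketch for `ones` above, kept as comments so nothing
# executes at import time (`DynamicRaggedShape.from_lengths` is defined
# earlier in this module):
#
#   shape = DynamicRaggedShape.from_lengths([2, (2, 1)])
#   ones(shape)  # -> a RaggedTensor like [[1., 1.], [1.]], dtype float32
#
# `ones` builds dense flat values from `shape.inner_shape` and then reapplies
# `shape.row_partitions`, mirroring how the broadcasters above rebuild a
# ragged result from broadcast flat values.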