# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StructuredTensor array ops."""

from typing import Sequence

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.ragged import dynamic_ragged_shape
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch


@dispatch.dispatch_for_api(array_ops.shape_v2)
def shape_v2(input: StructuredTensor, out_type=dtypes.int32,  # pylint: disable=redefined-builtin
             name=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns a DynamicRaggedShape containing the shape of the input.

  Args:
    input: the StructuredTensor whose shape is requested.
    out_type: the dtype the returned shape is converted to.
    name: ignored; accepted for signature compatibility with the dispatched op.

  Returns:
    The ragged shape of `input`, converted with `with_dtype(out_type)`.
  """
  del name
  return input._ragged_shape.with_dtype(out_type)  # pylint: disable=protected-access


@dispatch.dispatch_for_api(array_ops.shape)
def shape_v1(input: StructuredTensor, name=None,  # pylint: disable=redefined-builtin
             out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns a DynamicRaggedShape containing the shape of the input.

  Same as `shape_v2`, with the TF v1 argument order (`name` before `out_type`).

  Args:
    input: the StructuredTensor whose shape is requested.
    name: ignored; accepted for signature compatibility with the dispatched op.
    out_type: the dtype the returned shape is converted to.

  Returns:
    The ragged shape of `input`, converted with `with_dtype(out_type)`.
  """
  del name
  return input._ragged_shape.with_dtype(out_type)  # pylint: disable=protected-access


@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
def expand_dims(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.

  This is an implementation of tf.expand_dims for StructuredTensor. Note
  that the `axis` must be less than or equal to rank.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the original StructuredTensor.
    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.
    dim: deprecated: use axis.

  Returns:
    a new structured tensor with larger rank.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # Resolve the deprecated `dim` alias into `axis` before delegating.
  axis = deprecation.deprecated_argument_lookup('axis', axis, 'dim', dim)
  return _expand_dims_impl(input, axis, name=name)


@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
def expand_dims_v2(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.

  This is an implementation of tf.expand_dims for StructuredTensor. Note
  that the `axis` must be less than or equal to rank.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the original StructuredTensor.
    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.

  Returns:
    a new structured tensor with larger rank.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  return _expand_dims_impl(input, axis, name=name)


@dispatch.dispatch_for_types(array_ops.gather, StructuredTensor)
def gather(params,
           indices,
           validate_indices=None,
           name=None,
           axis=None,
           batch_dims=0):
  """tf.gather for structured tensors.

  Does not support (yet) checks on illegal axis values, et cetera.

  Indices must be a ragged or dense tensor.

  Args:
    params: a structured tensor to be gathered
    indices: a ragged tensor or tensor to gather by.
    validate_indices: whether to validate the indices
    name: the name of the op(s). Defaults to 'gather'.
    axis: the axis in params to gather on. Defaults to `batch_dims`.
    batch_dims: the number of batch dimensions.

  Returns:
    the params reorganized according to indices.
  """
  if name is None:
    name = 'gather'
  with ops.name_scope(name):
    if axis is None:
      axis = batch_dims
    # Normalize a possibly negative axis against the rank of `params`.
    axis = array_ops.get_positive_axis(axis, params.shape.rank,
                                       ndims_name='params.shape.rank')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')

    def leaf_op(p):
      # Applied to every non-structured field value of `params`.
      return array_ops.gather(
          p,
          indices,
          validate_indices=validate_indices,
          axis=axis,
          batch_dims=batch_dims,
          name=None)

    return _extend_op_single(params, leaf_op)


@dispatch.dispatch_for_types(array_ops.concat, StructuredTensor)
def concat(values, axis, name: str = 'concat'):
  """tf.concat for structured tensors.

  Does not support (yet) checks on illegal axis values, et cetera.

  Args:
    values: a sequence of StructuredTensors.
    axis: an axis to concatenate upon.
    name: the name of the op(s).

  Returns:
    the StructuredTensor obtained by concatenating `values` along `axis`,
    field by field.

  Raises:
    ValueError: if `values` is not a non-empty sequence of StructuredTensors
      with identical paths and submessage ranks.
  """
  if name is None:
    name = 'concat'
  _assert_concat_compatible_structured_tensors(values)
  # TODO(martinz): handle axis when it is a tensor.
  # Resolve the axis before defining leaf_op, so the closure visibly captures
  # the non-negative axis rather than relying on late binding.
  axis = array_ops.get_positive_axis(axis, values[0].rank)

  def leaf_op(values):
    return array_ops.concat(values, axis)

  with ops.name_scope(name, 'StructuredConcat', values):
    return _extend_op(values, leaf_op)


@dispatch.dispatch_for_types(random_ops.random_shuffle, StructuredTensor)
def random_shuffle(value, seed=None, name=None):
  """Shuffle a structured tensor on the zeroth axis.

  Args:
    value: a structured tensor of rank at least one.
    seed: the seed for shuffling.
    name: the name for shuffle.

  Returns:
    The shuffled structured tensor.

  Raises:
    ValueError: if `value` has rank 0.
  """
  with ops.name_scope(name, 'shuffle', [value, seed]):
    if value.rank == 0:
      raise ValueError('Cannot shuffle a scalar StructuredTensor')
    first_dimension = value.nrows()
    # Shuffle the row indices, then gather rows in the shuffled order.
    index = random_ops.random_shuffle(math_ops.range(first_dimension),
                                      seed=seed)
    return gather(value, index, axis=0)


@dispatch.dispatch_for_types(array_ops.size_v2, StructuredTensor)
def size_v2(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  Same as `size`, with the TF v2 argument order.

  Args:
    input: a structured tensor.
    out_type: the dtype of the result.
    name: a name for the op.

  Returns:
    The result of `size(input, name=name, out_type=out_type)`.
  """
  return size(input, name=name, out_type=out_type)


# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.size, StructuredTensor)
def size(input, name=None, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  Args:
    input: a structured tensor.
    name: a name for the op.
    out_type: the dtype the result is cast to.

  Returns:
    With no row partitions: `nrows()` cast to `out_type` when it is not None
    (vector), otherwise the constant 1 (scalar). With row partitions: the
    number of values of the innermost row partition, cast to `out_type`
    unless either is None.
  """
  with ops.name_scope(name, 'size', [input]) as name:
    if not input.row_partitions:
      if input.nrows() is not None:
        return math_ops.cast(input.nrows(), out_type)  # vector.
      else:
        return math_ops.cast(1, out_type)  # scalar.
    # 2D and up.
    nvals = input.row_partitions[-1].nvals()
    if nvals is None or out_type is None:
      return nvals
    return math_ops.cast(nvals, dtype=out_type)


# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like, StructuredTensor)
def zeros_like(tensor, dtype=None, name=None, optimize=True):
  """Implementation of zeros_like for StructuredTensor for TF v1.

  Args:
    tensor: a structured tensor.
    dtype: the dtype of the resulting zeros. (default is tf.float32)
    name: a name for the op.
    optimize: ignored; accepted for signature compatibility.

  Returns:
    The result of `zeros_like_v2(tensor, dtype=dtype, name=name)`.
  """
  del optimize
  return zeros_like_v2(tensor, dtype=dtype, name=name)


# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like_v2, StructuredTensor)
def zeros_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a zero.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.zeros_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 0.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.zeros_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[0], [0, 0]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting zeros. (default is tf.float32)
    name: a name for the op.

  Returns:
    a tensor of zeros of the same shape.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'zeros_like', [input]) as name:
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.zeros([input.nrows()], dtype)  # vector.
      else:
        return array_ops.zeros([], dtype)  # scalar.
    # 2D and up.
    last_row_partition = input.row_partitions[-1]

    # One zero per innermost value, wrapped in the input's row partitions.
    result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.zeros(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)
    return result


# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like, StructuredTensor)
def ones_like(tensor, dtype=None, name=None, optimize=True):
  """Implementation of ones_like for StructuredTensor for TF v1.

  Args:
    tensor: a structured tensor.
    dtype: the dtype of the resulting ones. (default is tf.float32)
    name: a name for the op.
    optimize: ignored; accepted for signature compatibility.

  Returns:
    The result of `ones_like_v2(tensor, dtype=dtype, name=name)`.
  """
  del optimize
  return ones_like_v2(tensor, dtype=dtype, name=name)


# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like_v2, StructuredTensor)
def ones_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a one.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.ones_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.ones_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[1], [1, 1]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting ones. (default is tf.float32)
    name: a name for the op.

  Returns:
    a tensor of ones of the same shape.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'ones_like', [input]) as name:
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.ones([input.nrows()], dtype)  # vector.
      else:
        return array_ops.ones([], dtype)  # scalar.
    # 2D and up.
    last_row_partition = input.row_partitions[-1]

    # One `1` per innermost value, wrapped in the input's row partitions.
    result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.ones(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)
    return result


@dispatch.dispatch_for_types(array_ops.rank, StructuredTensor)
def rank(input, name=None):
  # pylint: disable=redefined-builtin
  """Returns the rank of a tensor.

  Args:
    input: a structured tensor.
    name: a name for the op.

  Returns:
    An int32 constant holding `input.rank`.
  """
  with ops.name_scope(name, 'rank', [input]) as name:
    return constant_op.constant(input.rank, dtype=dtypes.int32)


def _expand_dims_impl(st, axis, name=None):
  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.

  This is an implementation of tf.expand_dims for StructuredTensor. Note
  that the `axis` must be less than or equal to rank.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    st: the original StructuredTensor.
    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.

  Returns:
    a new structured tensor with larger rank.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # The result has rank `st.rank + 1`, so the axis is normalized against that.
  axis = array_ops.get_positive_axis(
      axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
  with ops.name_scope(name, 'ExpandDims', [st, axis]):
    new_fields = {
        k: array_ops.expand_dims(v, axis) for (k, v) in st._fields.items()
    }
    new_shape = st.shape[:axis] + (1,) + st.shape[axis:]
    new_row_partitions = _expand_st_row_partitions(st, axis)
    # Inserting at axis 0 makes the new outermost dimension have one row.
    new_nrows = st.nrows() if (axis > 0) else 1
    return StructuredTensor.from_fields(
        new_fields,
        shape=new_shape,
        row_partitions=new_row_partitions,
        nrows=new_nrows)


def _expand_st_row_partitions(st, axis):
  """Create the row_partitions for expand_dims.

  Args:
    st: the original StructuredTensor.
    axis: the non-negative axis at which the length 1 dimension is inserted.

  Returns:
    A tuple of row partitions for the expanded StructuredTensor.
  """
  if axis == 0:
    if st.shape.rank == 0:
      return ()
    # New outermost partition: a single row containing all original rows.
    nvals = st.nrows()
    new_partition = RowPartition.from_uniform_row_length(
        nvals, nvals, nrows=1, validate=False)
    return (new_partition,) + st.row_partitions
  elif axis == st.rank:
    # New innermost partition: every existing innermost value becomes a row
    # of length 1.
    nvals = (
        st.row_partitions[axis - 2].nvals() if (axis - 2 >= 0) else st.nrows())
    return st.row_partitions + (RowPartition.from_uniform_row_length(
        1, nvals, nrows=nvals, validate=False),)
  else:
    # New partition in the middle: rows of length 1, inserted before the
    # partition at index `axis - 1`.
    nvals = (
        st.row_partitions[axis - 1].nrows() if (axis - 1 >= 0) else st.nrows())
    return st.row_partitions[:axis - 1] + (RowPartition.from_uniform_row_length(
        1, nvals, nrows=nvals, validate=False),) + st.row_partitions[axis - 1:]


# TODO(martinz): consider allowing values to be nested.
def _extend_op(values, leaf_op, empty_st_op=None):
  """Extend an op from RaggedTensor and Tensor to StructuredTensor.

  Visits all children of the structured tensor, and children of children,
  applying leaf_op whenever it reaches a leaf, and empty_st_op whenever
  it reaches an internal node without children.

  Args:
    values: a list of structured tensors, ragged tensors, or tensors. All must
      have the same type. If they are structured tensors, they must have the
      same paths.
    leaf_op: an op for handling non-structured tensor.
    empty_st_op: op to create a structured tensor without fields. Defaults to
      `empty_st_op_like_zeros(leaf_op)`.

  Returns:
    the result of the extended op (a StructuredTensor, RaggedTensor, or Tensor)

  Raises:
    ValueError:
      If values is not a Sequence or is empty.
  """
  if not isinstance(values, Sequence):
    raise ValueError('Expected a list')

  if not values:
    raise ValueError('List cannot be empty')

  if empty_st_op is None:
    empty_st_op = empty_st_op_like_zeros(leaf_op)
  # Use the structure of the first StructuredTensor. They are all assumed to
  # be the same.
  value = values[0]

  if isinstance(value, StructuredTensor):
    # TODO(martinz): Calling empty_st_op may add unnecessary ops. Revisit later.
    empty_result = empty_st_op(values)
    if not value.field_names():
      return empty_result
    new_fields = {}
    for k in value.field_names():
      # Recurse on the same-named field of every input.
      new_fields[k] = _extend_op([v.field_value(k) for v in values], leaf_op,
                                 empty_st_op)
    # The field-less result only supplies the shape of the output.
    return StructuredTensor.from_fields(new_fields, shape=empty_result.shape)
  else:
    return leaf_op(values)


def _extend_op_single(value, leaf_op, empty_st_op=None):
  """Extend an op to a value instead of a list of values.

  Args:
    value: a structured tensor, ragged tensor, or tensor.
    leaf_op: an op taking a single non-structured tensor.
    empty_st_op: optional op taking a single value and creating a structured
      tensor without fields.

  Returns:
    The result of `_extend_op` on the one-element list `[value]`.
  """

  def to_list_op(element_op):
    """Adapts a single-value op to take a one-element list; None stays None."""
    if element_op is None:
      return None

    def list_op(values):
      [value] = values
      return element_op(value)

    return list_op

  return _extend_op([value], to_list_op(leaf_op), to_list_op(empty_st_op))


def empty_st_op_like_zeros(leaf_op):
  """Builds an empty-StructuredTensor op from a leaf op.

  Args:
    leaf_op: an op taking a list of non-structured tensors.

  Returns:
    A function taking a list of values. It replaces each value with
    `zeros_like_v2(value, dtype=dtypes.int32)`, applies `leaf_op` to that list,
    and returns a field-less StructuredTensor shaped like the result.
  """

  def empty_st_op(values):
    as_zeros = [
        zeros_like_v2(value, dtype=dtypes.int32) for value in values
    ]
    result = leaf_op(as_zeros)
    return _structured_tensor_like(result)

  return empty_st_op


def _structured_tensor_from_dense_tensor(t):
  """Create a structured tensor with the shape of a dense tensor.

  Args:
    t: a dense tensor.

  Returns:
    A StructuredTensor without fields, with the shape of `t`.

  Raises:
    ValueError: if the shape of `t` is not fully defined and its rank is
      unknown.
  """
  # Note: If a tensor will have rank 0,
  # it either has a fully defined shape or has unknown rank.
  if t.shape.is_fully_defined():
    return StructuredTensor.from_fields({}, shape=t.shape)
  elif t.shape.rank is None:
    raise ValueError("Can't build StructuredTensor w/ unknown rank")
  elif t.shape.rank == 1:
    return StructuredTensor.from_fields({}, shape=t.shape,
                                        nrows=array_ops.shape(t)[0])
  else:
    # Borrow the row partitions of the equivalent RaggedTensor.
    rt = ragged_tensor.RaggedTensor.from_tensor(t)
    return _structured_tensor_from_row_partitions(t.shape,
                                                  rt._nested_row_partitions)


def _structured_tensor_from_row_partitions(shape, row_partitions):
  """Create a field-less StructuredTensor from a shape and row partitions."""
  return StructuredTensor.from_fields({},
                                      shape=shape,
                                      row_partitions=row_partitions)


# pylint: disable=protected-access
def _all_nested_row_partitions(rt):
  """Returns all nested row partitions in rt, including for dense dimensions."""
  if isinstance(rt, ops.Tensor):
    if rt.shape.rank <= 1:
      return ()
    else:
      rt2 = ragged_tensor.RaggedTensor.from_tensor(rt)
      return rt2._nested_row_partitions
  else:
    # Ragged partitions first, then those of the (possibly dense) flat values.
    tail_partitions = _all_nested_row_partitions(rt.flat_values)
    head_partitions = rt._nested_row_partitions  # pylint: disable=protected-access
    return head_partitions + tail_partitions


def _structured_tensor_like(t):
  """Create a StructuredTensor with the shape of a (composite) tensor.

  Args:
    t: a Tensor, RaggedTensor, or StructuredTensor.

  Returns:
    A StructuredTensor without fields, shaped like `t`.
  """
  if isinstance(t, ops.Tensor):
    return _structured_tensor_from_dense_tensor(t)
  if ragged_tensor.is_ragged(t):
    return StructuredTensor.from_fields(
        {}, shape=t.get_shape(), row_partitions=_all_nested_row_partitions(t))
  # here, it is a StructuredTensor
  return StructuredTensor.from_fields({},
                                      shape=t.shape,
                                      row_partitions=t.row_partitions,
                                      nrows=t.nrows())


def _get_all_paths(st):
  """Get all the paths from a StructuredTensor.

  Args:
    st: a StructuredTensor.

  Returns:
    A set of tuples of field names. It always contains the empty path `()`,
    plus one path per field, recursing into StructuredTensor-valued fields.
  """
  fields = st.field_names()
  all_paths = {()}
  for k in fields:
    v = st.field_value(k)
    if isinstance(v, StructuredTensor):
      all_paths = all_paths.union([(k,) + p for p in _get_all_paths(v)])
    else:
      all_paths.add((k,))
  return all_paths


def _get_all_ranks(st):
  """Get ranks of all submessages of a StructuredTensor.

  Args:
    st: a StructuredTensor.

  Returns:
    A dict mapping each submessage path (a tuple of field names, `()` for `st`
    itself) to the rank of that submessage. Non-StructuredTensor fields are
    not included.
  """
  fields = st.field_names()
  all_ranks = {(): st.rank}
  for k in fields:
    v = st.field_value(k)
    if isinstance(v, StructuredTensor):
      for (k2, v2) in _get_all_ranks(v).items():
        all_ranks[(k,) + k2] = v2
  return all_ranks


def _assert_all_paths_match(values):
  """Raises an error if the paths are not identical.

  Args:
    values: a sequence of StructuredTensors.

  Raises:
    ValueError: if any path is present in some, but not all, of `values`.
  """
  paths = [_get_all_paths(st) for st in values]
  path_diff = set()
  for other_paths in paths[1:]:
    path_diff = path_diff.union(paths[0].symmetric_difference(other_paths))
  if path_diff:
    raise ValueError(
        'Some paths are present in some, but not all, structured tensors: %r' %
        (path_diff,))


def _assert_all_ranks_match(values):
  """Raises an error if the ranks of submessages are not identical.

  Args:
    values: a sequence of StructuredTensors.

  Raises:
    ValueError: if the path-to-rank mapping of any element differs from that
      of the first element.
  """
  ranks = [_get_all_ranks(st) for st in values]
  for other_ranks in ranks[1:]:
    if other_ranks != ranks[0]:
      # TODO(martinz): If this becomes common, we can provide more detail.
      # e.g.: which path is inconsistent.
      raise ValueError('Ranks of sub-message do not match')


def _assert_concat_compatible_structured_tensors(values):
  """Sometimes raises an error if concat doesn't make sense statically on values.

  values must be a sequence, and each element in values must be a structured
  tensor, and must have the same paths. Additionally, each path that is a
  submessage must have the same rank.

  These constraints are sufficient for concat on the fields to be the same
  as concat on structured tensors. This is meant to capture scenarios like
  paths that are not in the first structured tensor, but are in later
  structured tensors, which will just be ignored by the recursive algorithm.

  If the rank of a submessage was different for two structured tensors,
  then that is also a non-sensical merge.

  Note that all of these checks are static, as paths and submessage ranks
  are known.

  Args:
    values: a Sequence of StructuredTensors.

  Raises:
    ValueError: if there is any inconsistency as described above.
  """
  if not isinstance(values, Sequence):
    # The message keeps its original prefix and now reports the offending
    # type instead of the contradictory "(not a list)".
    raise ValueError('values must be a list of StructuredTensors (got %s)' %
                     type(values).__name__)
  if not values:
    raise ValueError('values must not be an empty list')
  for st in values:
    if not isinstance(st, StructuredTensor):
      raise ValueError('values must be a list of StructuredTensors')
  _assert_all_paths_match(values)
  _assert_all_ranks_match(values)