1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Structured Tensors.""" 16 17import re 18from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union 19 20import numpy as np 21 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import extension_type 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.framework import tensor_spec 28from tensorflow.python.framework import type_spec 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import check_ops 31from tensorflow.python.ops import control_flow_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops.ragged import dynamic_ragged_shape 34from tensorflow.python.ops.ragged import ragged_factory_ops 35from tensorflow.python.ops.ragged import ragged_tensor 36from tensorflow.python.ops.ragged.row_partition import RowPartition 37from tensorflow.python.util import compat 38from tensorflow.python.util import nest 39 40 41class StructuredTensor(extension_type.BatchableExtensionType): 42 """A multidimensional collection of structures with the same schema. 
43 44 A **`StructuredTensor`** is a multi-dimensional collection of ***structures*** 45 with the same ***schema***, where: 46 47 * A ***schema*** is a collection of fields, each of which has a name and type. 48 * A ***structure*** maps each field in the schema to a tensor value (which 49 could be a nested StructuredTensor). 50 51 As an important special case, a 1D `StructuredTensor` encodes a 2D table, 52 where columns are heterogeneous `Tensor`s, and rows are the aligned elements 53 in each of those `Tensor`s. 54 55 Internally, StructuredTensors use a "field-major" encoding: for each leaf 56 field, there is a single tensor that stores the value of that field for all 57 structures in the `StructuredTensor`. 58 59 ### Examples 60 61 >>> # A scalar StructuredTensor describing a single person. 62 >>> s1 = StructuredTensor.from_pyval( 63 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}) 64 >>> s1.shape 65 TensorShape([]) 66 >>> s1["age"] 67 <tf.Tensor: shape=(), dtype=int32, numpy=82> 68 69 >>> # A vector StructuredTensor describing three people. 70 >>> s2 = StructuredTensor.from_pyval([ 71 ... {"age": 12, "nicknames": ["Josaphine"]}, 72 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}, 73 ... {"age": 42, "nicknames": ["Elmo"]}]) 74 >>> s2.shape 75 TensorShape([3]) 76 >>> s2[0]["age"] 77 <tf.Tensor: shape=(), dtype=int32, numpy=12> 78 79 80 ### Field Paths 81 82 A *field path* is a tuple of field names, specifying the path to a nested 83 field. 
"""
  # Field name -> field value.  Values may be `Tensor`s, `RaggedTensor`s,
  # nested `StructuredTensor`s, or other `ExtensionType`s.
  _fields: Mapping[str, Union[ops.Tensor, ragged_tensor.RaggedTensor,
                              'StructuredTensor', extension_type.ExtensionType]]
  # The (possibly ragged) shape of this StructuredTensor itself.  The inner
  # dimensions of individual field values are not part of this shape.
  _ragged_shape: dynamic_ragged_shape.DynamicRaggedShape

  __name__ = 'tf.StructuredTensor'
  #=============================================================================
  # Common Types
  #=============================================================================
  # pylint: disable=invalid-name
  # Field names work as key, and they can be a sequence to refer to the
  # sub-levels (embedded) StructuredTensor's.
  FieldName = Union[str, Sequence[str]]

  # Each field may contain one of the following types of Tensors.
  FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor']

  # Function that takes a FieldValue as input and returns the transformed
  # FieldValue.
  FieldFn = Callable[[FieldValue], FieldValue]

  # pylint: enable=invalid-name

  #=============================================================================
  # Constructor & Factory Methods
  #=============================================================================
  def __init__(self, fields: Mapping[str, FieldValue],
               ragged_shape: dynamic_ragged_shape.DynamicRaggedShape):
    """Low-level constructor: stores `fields` and `ragged_shape` as given.

    No validation, conversion, or copying is performed here.  Use the factory
    methods (e.g. `from_fields`, `from_fields_and_rank`, `from_pyval`) to
    build a `StructuredTensor` from user-supplied values.

    Args:
      fields: Mapping from field name to field value.  Stored by reference.
      ragged_shape: The shape of this `StructuredTensor`.
    """
    self._fields = fields
    self._ragged_shape = ragged_shape

  @classmethod
  def _old_init(cls, fields, shape, nrows, row_partitions, internal=False):
    """Private constructor -- use factory methods to create StructuredTensors.

    This constructor builds a `StructuredTensor` from the given attributes,
    performing minimal validation.  Unlike `__init__`, it derives the ragged
    shape from `shape`, `nrows` and `row_partitions`.

    Args:
      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
        `StructuredTensor`.  (This dict is not copied, so the caller must ensure
        that it does not get mutated via leaked references.)
      shape: `tf.TensorShape` with statically known rank.
      nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`.
128 row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`. 129 internal: ignored argument. 130 Returns: 131 a StructuredTensor. 132 """ 133 assert isinstance(fields, dict), fields 134 assert isinstance(shape, tensor_shape.TensorShape), shape 135 assert nrows is None or isinstance(nrows, ops.Tensor), nrows 136 assert row_partitions is None or isinstance(row_partitions, 137 tuple), row_partitions 138 return StructuredTensor( 139 fields=fields, 140 ragged_shape=_dynamic_ragged_shape_init(fields, shape, nrows, 141 row_partitions)) 142 143 @classmethod 144 def from_shape( 145 cls, ragged_shape: dynamic_ragged_shape.DynamicRaggedShape 146 ) -> 'StructuredTensor': 147 """Creates a `StructuredTensor` with no fields and ragged_shape. 148 149 Args: 150 ragged_shape: the shape of the structured tensor. 151 152 Returns: 153 a StructuredTensor with no fields and ragged_shape. 154 """ 155 return StructuredTensor(fields={}, ragged_shape=ragged_shape) 156 157 @classmethod 158 def from_fields(cls, 159 fields, 160 shape=(), 161 nrows=None, 162 row_partitions=None, 163 validate=False): 164 """Creates a `StructuredTensor` from a dictionary of fields. 165 166 Args: 167 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 168 `StructuredTensor`, providing the values for individual fields in each 169 structure. If `shape.rank > 0`, then every tensor in `fields` must have 170 the same shape in the first `shape.rank` dimensions; and that shape must 171 be compatible with `shape`; and `result[i1...iN][key] = 172 fields[key][i1...iN]` (where `N==shape.rank`). 173 shape: A `TensorShape`: static information about the shape of the 174 `StructuredTensor`. Must have a known `rank`. Defaults to scalar shape 175 (i.e. `rank=0`). 176 nrows: scalar integer tensor containing the number of rows in this 177 `StructuredTensor`. Should only be specified if `shape.rank > 0`. 178 Default value is inferred from the `fields` values. 
If `fields` is 179 empty, then this must be specified. 180 row_partitions: A list of `RowPartition`s describing the (possibly ragged) 181 shape of this `StructuredTensor`. Should only be specified if 182 `shape.rank > 1`. Default value is inferred from the `fields` values. 183 If `fields` is empty, then this must be specified. 184 validate: If true, then add runtime validation ops that check that the 185 field values all have compatible shapes in the outer `shape.rank` 186 dimensions. 187 188 Returns: 189 A `StructuredTensor`. 190 191 Examples: 192 193 >>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]}) 194 <StructuredTensor( 195 fields={ 196 "x": tf.Tensor(1, shape=(), dtype=int32), 197 "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)}, 198 shape=())> 199 200 >>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]}, 201 ... shape=[2]) 202 <StructuredTensor( 203 fields={ 204 "bar": tf.Tensor([3 4], shape=(2,), dtype=int32), 205 "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)}, 206 shape=(2,))> 207 """ 208 shape = tensor_shape.as_shape(shape) 209 rank = shape.rank 210 if rank is None: 211 raise ValueError("StructuredTensor's shape must have known rank.") 212 if not isinstance(fields, dict): 213 raise TypeError('fields must be a dictionary, got %s' % 214 type(fields).__name__) 215 if rank < 2 and row_partitions: 216 raise ValueError('row_partitions must be None or [] if shape.rank<2') 217 if rank == 0 and nrows is not None: 218 raise ValueError('nrows must be None if shape.rank==0') 219 if row_partitions is not None: 220 row_partitions = tuple(row_partitions) 221 if len(row_partitions) != max(0, rank - 1): 222 raise ValueError('len(row_partitions) must be shape.rank-1') 223 elif rank < 2: 224 row_partitions = () 225 226 fields = dict(fields) # Make a private copy. 227 with ops.name_scope(None, 'StructuredTensor', fields.values()): 228 # TODO(martinz): Make this have better errors. 
229 shape = _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions) 230 231 # TODO(martinz): This may not need to be done if all fields are dense. 232 if shape.rank > 1: 233 shape = shape._with_num_row_partitions(shape.rank - 1) 234 235 # Validate keys and convert field values to tensors. 236 for key, value in fields.items(): 237 if not isinstance(key, str): 238 239 raise TypeError( 240 f'Unexpected type for key in `fields`: {key}') 241 if not _FIELD_NAME_RE.match(key): 242 raise ValueError('Field name %r is not currently allowed.' % key) 243 fields[key] = _convert_to_structured_field_value(value) 244 245 fields = dict([(k, _replace_row_partitions(v, row_partitions)) 246 for (k, v) in fields.items()]) 247 return cls(fields=fields, ragged_shape=shape) 248 249 @classmethod 250 def from_fields_and_rank(cls, fields, rank, validate=False): 251 """Creates a `StructuredTensor` from a nonempty dictionary of fields. 252 253 Args: 254 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 255 `StructuredTensor`, providing the values for individual fields in each 256 structure. If `rank > 0`, then every tensor in `fields` must have 257 the same shape in the first `rank` dimensions. Cannot be empty. 258 rank: The rank of the resulting structured tensor. 259 validate: If true, then add runtime validation ops that check that the 260 field values all have compatible shapes in the outer `rank` 261 dimensions. 262 263 Returns: 264 A `StructuredTensor`. 265 Examples: 266 >>> StructuredTensor.from_fields_and_rank({'x': 1, 'y': [1, 2, 3]}, 0) 267 <StructuredTensor( 268 fields={ 269 "x": tf.Tensor(1, shape=(), dtype=int32), 270 "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)}, 271 shape=())> 272 >>> StructuredTensor.from_fields_and_rank({'foo': [1, 2], 'bar': [3, 4]}, 273 ... 
1) 274 <StructuredTensor( 275 fields={ 276 "bar": tf.Tensor([3 4], shape=(2,), dtype=int32), 277 "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)}, 278 shape=(2,))> 279 """ 280 if not fields: 281 raise ValueError('Must provide at least one field') 282 if not isinstance(rank, int): 283 raise ValueError('rank must be an integer') 284 if rank < 0: 285 raise ValueError('rank must be nonnegative') 286 fields = { 287 k: _convert_to_structured_field_value(v) for (k, v) in fields.items() 288 } 289 dtype = _find_shape_dtype(fields, None, None) 290 291 shape = _shape_from_fields(fields, rank, dtype) 292 if rank > 1: 293 shape = shape._with_num_row_partitions(rank - 1) 294 new_rp = shape._row_partitions # pylint: disable=protected-access 295 fields = { 296 k: _replace_row_partitions(v, new_rp) for (k, v) in fields.items() 297 } 298 return StructuredTensor(fields=fields, ragged_shape=shape) 299 300 def with_updates( 301 self, 302 updates: Dict[FieldName, Union[FieldValue, FieldFn, None]], 303 validate: bool = False 304 ) -> 'StructuredTensor': 305 """Creates a new `StructuredTensor` with the updated fields. 306 307 If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being 308 updated and `v` the new value, then: 309 310 ``` 311 result[k] = v # If (k, v) is in updates and v is a FieldValue 312 result[k] = f(self[k]) # If (k, f) is in updates and f is a FieldFn 313 result[k] = self[k] # If k is in self.field_names but not in updates 314 ``` 315 316 If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each 317 FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is, 318 prefixed with the same shape as the `StructuredTensor`. 
Then the resulting 319 `StructuredTensor` will have: 320 321 ``` 322 result[i1...iN][k] = v[i1...iN] # (k, v) in updates 323 result[i1...iN][k] = f(self.field_value(k))[i1...iN] # (k, f) in updates 324 result[i1...iN][k] = self[i1...iN][k] # k not in updates 325 ``` 326 327 Note that `result.shape` is always equal to `self.shape` (but the shapes 328 of nested StructuredTensors may be changed if they are updated with new 329 values). 330 331 Args: 332 updates: A dictionary mapping `FieldName` to either a `FieldValue` to be 333 used to update, or a `FieldFn` that will transform the value for the 334 given `FieldName`. `FieldName` can be a string for a direct field, or a 335 sequence of strings to refer to a nested sub-field. `FieldFn` is a 336 function that takes a `FieldValue` as input and should return a 337 `FieldValue`. All other fields are copied over to the new 338 `StructuredTensor`. New `FieldName` can be given (to add new fields), 339 but only to existing `StructuredTensor`, it won't automatically create 340 new nested structures -- but one can create a whole `StructureTensor` 341 sub-structure and set that into an existing structure. If the new value 342 is set to `None`, it is removed. 343 validate: If true, then add runtime validation ops that check that the 344 field values all have compatible shapes in the outer `shape.rank` 345 dimensions. 346 347 Returns: 348 A `StructuredTensor`. 349 350 Raises: 351 `ValueError`: If the any of the `FieldName` keys points to non-existent 352 sub-structures, if parent and child nodes are updated, if shapes 353 change, if a delete update is given for a non-existant field, or if a 354 `FieldFn` transforming function is given for a `FieldName` that doesn't 355 yet exist. 356 357 Examples: 358 359 >>> shoes_us = StructuredTensor.from_pyval([ 360 ... {"age": 12, "nicknames": ["Josaphine"], 361 ... "shoes": {"sizes": [8.0, 7.5, 7.5]}}, 362 ... {"age": 82, "nicknames": ["Bob", "Bobby"], 363 ... 
"shoes": {"sizes": [11.0, 11.5, 12.0]}}, 364 ... {"age": 42, "nicknames": ["Elmo"], 365 ... "shoes": {"sizes": [9.0, 9.5, 10.0]}}]) 366 >>> def us_to_europe(t): 367 ... return tf.round(t * 2.54 + 17.0) # Rough approximation. 368 >>> shoe_sizes_key = ("shoes", "sizes") 369 >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe}) 370 >>> shoes_eu.field_value(shoe_sizes_key) 371 <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0], 372 [40.0, 41.0, 42.0]]> 373 """ 374 updates_items = [(_normalize_field_name_to_tuple(name), value) 375 for name, value in updates.items()] 376 377 # Sort by keys and check for updates of both parent and child nodes. 378 updates_items = sorted(updates_items) 379 for i in range(1, len(updates_items)): 380 # Parent of a node would precede node in the sorted order. 381 name = updates_items[i][0] # item[0] is the name, item[1] is the value. 382 prev_name = updates_items[i - 1][0] 383 if name[:len(prev_name)] == prev_name: 384 raise ValueError( 385 '`StructuredTensor.with_updates` does not allow both parent and ' 386 'child nodes to be updated: parent={}, child={}. If needed you can ' 387 'update child nodes in the parent update value.'.format( 388 prev_name, name)) 389 return self._with_updates_impl((), updates_items, validate) 390 391 def _with_updates_impl( 392 self, 393 error_prefix: Tuple[str], 394 updates: List[Tuple[FieldName, Union[FieldValue, FieldFn]]], 395 validate: bool) -> 'StructuredTensor': 396 """Recursive part of `with_updates` implementation.""" 397 # Get current fields. 398 new_fields = dict(self._fields) 399 400 # Convert field name to string with full path for error messages. 401 def name_fullpath(name: Sequence[str]) -> str: 402 return str(error_prefix + (name,)) 403 404 # Apply value if a function or the value itself. 405 def apply_value(name: str, value: Union['FieldValue', 406 'FieldFn']) -> 'FieldValue': 407 if callable(value): 408 # `value` is actually a transforming function. 
409 if name not in new_fields: 410 raise ValueError( 411 '`StructuredTensor.with_updates` cannot update the field {} ' 412 'because a transforming function was given, but that field ' 413 'does not already exist.'.format(name_fullpath(name))) 414 value = value(new_fields[name]) 415 return value 416 417 # Merge updates. 418 for name, value in updates: 419 if not name or not name[0]: 420 raise ValueError( 421 '`StructuredTensor.with_updates` does not allow empty names ' 422 '{}.'.format(name_fullpath(name))) 423 424 if len(name) == 1: 425 name = name[0] 426 if value is None: 427 if name not in new_fields: 428 raise ValueError( 429 '`StructuredTensor.with_updates` cannot delete field ' 430 '{} because it is not present.'.format(name_fullpath(name))) 431 new_fields.pop(name) 432 else: 433 new_fields[name] = apply_value(name, value) 434 else: 435 # Recursive 436 prefix = name[0] 437 suffix = name[1:] 438 if prefix not in new_fields: 439 raise ValueError( 440 '`StructuredTensor.with_updates` cannot create new sub-field ' 441 '{} if parent field {} is not set.'.format( 442 error_prefix + tuple(name), name_fullpath(prefix))) 443 current_value = new_fields[prefix] 444 if not isinstance(current_value, StructuredTensor): 445 raise ValueError( 446 '`StructuredTensor.with_updates` cannot create new sub-field ' 447 '{} if parent structure {} is not a `StructuredTensor` that ' 448 'can contain sub-structures -- it is a `{}`.'.format( 449 error_prefix + tuple(name), name_fullpath(prefix), 450 type(current_value))) 451 one_update = [(suffix, value)] 452 453 # Accessing protected member in recursion. 454 # FutureWork: optimize by aggregating the recursions, instead of 455 # calling one at a time. 456 # pylint: disable=protected-access 457 value = current_value._with_updates_impl(error_prefix + (prefix,), 458 one_update, validate) 459 # pylint: enable=protected-access 460 new_fields[prefix] = value 461 462 # TODO(edloper): When validate=True, only validate the modified fields. 
463 try: 464 return StructuredTensor.from_fields( 465 new_fields, 466 shape=self.shape, 467 row_partitions=self.row_partitions, 468 nrows=self.nrows(), 469 validate=validate) 470 471 except ValueError as e: 472 msg = '`StructuredTensor.with_updates` failed' 473 if error_prefix: 474 msg = '{} for field {}'.format(msg, error_prefix) 475 raise ValueError(msg) from e 476 477 def _promote_helper(self, source_path, new_parent_path): 478 """Creates a promoted field without adding it to the structure. 479 480 Args: 481 source_path: the source path in the structured tensor. 482 new_parent_path: the new parent path. Must be a prefix of source_path. 483 484 Returns: 485 a composite tensor of source_path promoted. 486 Raises: 487 ValueError: if the shape of the field is unknown and the right strategy 488 cannot be determined. 489 """ 490 current_field = self.field_value(source_path) 491 new_parent_rank = self.field_value(new_parent_path).rank 492 parent_rank = self.field_value(source_path[:-1]).rank 493 if new_parent_rank == parent_rank: 494 return current_field 495 current_field_rank = current_field.shape.rank 496 if current_field_rank is None: 497 raise ValueError('Cannot determine if dimensions should be merged.') 498 inner_dim = min(parent_rank, current_field_rank - 1) 499 if inner_dim <= new_parent_rank: 500 return current_field 501 return _merge_dims_generic(current_field, new_parent_rank, inner_dim) 502 503 def promote(self, source_path, new_name): 504 """Promotes a field, merging dimensions between grandparent and parent. 505 506 >>> d = [ 507 ... {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]}, 508 ... 
{'docs': [{'tokens':[7]}]}] 509 >>> st = StructuredTensor.from_pyval(d) 510 >>> st2 =st.promote(('docs','tokens'), 'docs_tokens') 511 >>> st2[0]['docs_tokens'] 512 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)> 513 >>> st2[1]['docs_tokens'] 514 <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)> 515 516 Args: 517 source_path: the path of the field or substructure to promote; must have 518 length at least 2. 519 new_name: the name of the new field (must be a string). 520 521 Returns: 522 a modified structured tensor with the new field as a child of the 523 grandparent of the source_path. 524 525 Raises: 526 ValueError: if source_path is not a list or a tuple or has a length 527 less than two, or new_name is not a string, or the rank 528 of source_path is unknown and it is needed. 529 """ 530 if not isinstance(new_name, str): 531 raise ValueError('new_name is not a string') 532 if not isinstance(source_path, (list, tuple)): 533 raise ValueError('source_path must be a list or tuple') 534 535 if len(source_path) < 2: 536 raise ValueError('source_path must have length at least two') 537 538 grandparent_path = source_path[:-2] 539 new_field = self._promote_helper(source_path, grandparent_path) 540 new_path = grandparent_path + (new_name,) 541 return self.with_updates({new_path: new_field}) 542 543 #============================================================================= 544 # Properties 545 #============================================================================= 546 547 @property 548 def rank(self): 549 """The rank of this StructuredTensor. Guaranteed not to be `None`.""" 550 return self._ragged_shape.rank 551 552 @property 553 def shape(self): 554 """The static shape of this StructuredTensor. 555 556 The returned `TensorShape` is guaranteed to have a known rank, but the 557 individual dimension sizes may be unknown. 

    Returns:
      `tf.TensorShape`
    """
    return self._ragged_shape._to_tensor_shape()  # pylint: disable=protected-access

  # TODO(martinz): for backwards compatibility
  @property
  def _row_partitions(self):
    """Deprecated form of row_partitions; returns `self.row_partitions`."""
    return self.row_partitions

  # TODO(edloper): Make this a func instead of a property?  Or make nrows
  # a property instead of a func?  Seems like these should be consistent.
  @property
  def row_partitions(self):
    """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.

    When `self.rank <= 1`, this tuple will be empty.

    When `self.rank > 1`, these `RowPartitions` define the shape of the
    `StructuredTensor` by describing how a flat (1D) list of structures can be
    repeatedly partitioned to form a higher-dimensional object.  In particular,
    the flat list is first partitioned into sublists using `row_partitions[-1]`,
    and then those sublists are further partitioned using `row_partitions[-2]`,
    etc.  The following examples show the row partitions used to describe
    several different `StructuredTensor`, each of which contains 8 copies of
    the same structure (`x`):

    >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']}       # shape = [] (scalar)

    >>> s1 = [[x, x, x, x], [x, x, x, x]]              # shape = [2, 4]
    >>> StructuredTensor.from_pyval(s1).row_partitions
    (tf.RowPartition(row_splits=[0 4 8]),)

    >>> s2 = [[x, x], [x, x], [x, x], [x, x]]          # shape = [4, 2]
    >>> StructuredTensor.from_pyval(s2).row_partitions
    (tf.RowPartition(row_splits=[0 2 4 6 8]),)

    >>> s3 = [[x, x, x], [], [x, x, x, x], [x]]        # shape = [4, None]
    >>> StructuredTensor.from_pyval(s3).row_partitions
    (tf.RowPartition(row_splits=[0 3 3 7 8]),)

    >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]]      # shape = [2, 2, 2]
    >>> StructuredTensor.from_pyval(s4).row_partitions
    (tf.RowPartition(row_splits=[0 2 4]),
     tf.RowPartition(row_splits=[0 2 4 6 8]))


    >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]]  # shape = [3, None, None]
    >>> StructuredTensor.from_pyval(s5).row_partitions
    (tf.RowPartition(row_splits=[0 2 3 5]),
     tf.RowPartition(row_splits=[0 2 3 5 7 8]))

    Note that shapes for nested fields (such as `x['b']` in the above example)
    are not considered part of the shape of a `StructuredTensor`, and are not
    included in `row_partitions`.

    If this `StructuredTensor` has a ragged shape (i.e., if any of the
    `row_partitions` is not uniform in size), then all fields will be encoded
    as either `RaggedTensor`s or `StructuredTensor`s with these `RowPartition`s
    used to define their outermost `self.rank` dimensions.

    Returns:
      A `tuple` of `RowPartition` objects with length `self.rank - 1`
      (or `0` if `self.rank < 2`)

    """
    if self.rank < 2:
      return ()
    return self._ragged_shape._as_row_partitions()  # pylint:disable=protected-access

  def nrows(self):
    """The number of rows in this StructuredTensor (if rank>0).

    This means the length of the outer-most dimension of the StructuredTensor.

    Notice that if `self.rank > 1`, then this equals the number of rows
    of the first row partition. That is,
    `self.nrows() == self.row_partitions[0].nrows()`.

    Otherwise `self.nrows()` will be the first dimension of the field values.

    Returns:
      A scalar integer `Tensor` (or `None` if `self.rank == 0`).
    """
    if self.rank == 0:
      return None
    return self._ragged_shape[0]

  def _is_eager(self):
    """True if all fields are composed of eager tensors."""
    # Flatten composites (RaggedTensor / StructuredTensor) down to their
    # component tensors so that every leaf is checked.
    tensors = nest.flatten(self, expand_composites=True)
    return all(isinstance(t, ops.EagerTensor) for t in tensors)

  #=============================================================================
  # Encoding
  #=============================================================================

  def field_names(self):
    """Returns the string field names for this `StructuredTensor`."""
    return tuple(self._fields.keys())

  def field_value(self, field_name):
    """Returns the tensor value for the specified field or path.

    If `field_name` is a `string`, then it names a field directly owned by this
    `StructuredTensor`.  If this `StructuredTensor` has shape `[D1...DN]`, then
    the returned tensor will have shape `[D1...DN, V1...VM]`, where the slice
    `result[d1...dN]` contains the field value for the structure at
    `self[d1...dN]`.

    If `field_name` is a `tuple` of `string`, then it specifies a path to a
    field owned by nested `StructuredTensor`.  In particular,
    `struct.field_value((f1, f2, ..., fN))` is equivalent to
    `struct.field_value(f1).field_value(f2)....field_value(fN)`

    Args:
      field_name: `string` or `tuple` of `string`: The field whose values should
        be returned.  A `list` of strings is treated like a `tuple`; an empty
        path returns `self`.

    Returns:
      `Tensor`, `StructuredTensor`, or `RaggedTensor`.

    Raises:
      KeyError: If the given field_name is not found.
    """
    if isinstance(field_name, (list, tuple)):
      value = self
      for f in field_name:
        # Only a StructuredTensor can be descended into further.
        if not isinstance(value, StructuredTensor):
          raise KeyError('Field path {} not found in {}'.format(
              field_name, self))
        value = value.field_value(f)
      return value
    return self._fields[field_name]

  #=============================================================================
  # Operators
  #=============================================================================

  # TODO(edloper): Add support for ellipsis and/or newaxis?
  def __getitem__(self, key):
    """Returns the specified piece of this StructuredTensor.

    * If `struct_tensor` is scalar (i.e., a single structure), then
      `struct_tensor[f]` returns the value of field `f` (where `f` must be a
      string).

    * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional
      tensor of structures), `struct_tensor[i]` selects an element or slice of
      the tensor using standard Python semantics (e.g., negative values index
      from the end).  `i` may have any of the following types:

      * `int` constant
      * `string` constant
      * scalar integer `Tensor`
      * `slice` containing integer constants and/or scalar integer
        `Tensor`s

    #### Multidimensional indexing

    `StructuredTensor` supports multidimensional indexing.  I.e., `key` may be a
    `tuple` of values, indexing or slicing multiple dimensions at once.  For
    example, if `people` is a vector of structures, each of which has a vector-
    valued `names` field, then `people[3, 'names', 0]` is equivalent to
    `people[3]['names'][0]`; and `people[:, 'names', :]` will return a (possibly
    ragged) matrix of names, with shape `[num_people, num_names_per_person]`.

    Args:
      key: Indicates which piece of the StructuredTensor to return.

    Returns:
      A `Tensor`, `StructuredTensor`, or `RaggedTensor`.
    """
    # Normalize `key` to a tuple of per-dimension keys.
    if isinstance(key, list):
      key = tuple(key)
    elif not isinstance(key, tuple):
      key = (key,)
    if not key:
      return self

    if self.rank == 0:
      return self._scalar_getitem(key)
    else:
      return self._tensor_getitem(key)

  def _scalar_getitem(self, key):
    """Indexes a rank-0 `StructuredTensor` with a (non-empty) tuple `key`.

    `key[0]` must be either a full slice (`:`), in which case `key[1:]` is
    applied to every field, or a string naming a field, in which case
    `key[1:]` is applied to that field's value.
    """
    if (isinstance(key[0], slice) and key[0].start is None and
        key[0].stop is None and key[0].step is None):
      fields = dict((field_name, field_value.__getitem__(key[1:]))
                    for (field_name, field_value) in self._fields.items())
      return StructuredTensor.from_fields(fields, self.shape)

    elif not isinstance(key[0], compat.bytes_or_text_types):
      raise ValueError('Key for indexing a StructuredTensor must be a '
                       "string or a full slice (':')")

    return self._fields[key[0]].__getitem__(key[1:])

  def _tensor_getitem(self, key):
    """Indexes a `StructuredTensor` with rank > 0 using a tuple `key`.

    If `len(key) <= rank`, every key indexes one of the outer dimensions: it is
    applied to each field, and the static result shape drops integer-indexed
    dimensions (non-full slices make the dimension unknown).  Otherwise
    `key[rank]` must be a string naming a field, and the remaining keys are
    applied to that field's value.
    """
    rank = self.rank
    if len(key) <= rank:
      new_fields = dict((field_name, field_value.__getitem__(key))
                        for (field_name, field_value) in self._fields.items())
      result_shape = self.shape.as_list()
      for d, k in enumerate(key):
        if isinstance(k, slice):
          if not (k.start is None and k.stop is None and k.step is None):
            # TODO(edloper): Better static shape analysis here.
            result_shape[d] = None
        elif isinstance(k, (int, ops.Tensor)):
          result_shape[d] = -1  # mark for deletion
        elif k is None:
          raise ValueError('Slicing not supported for tf.newaxis')
        else:
          # Any other key type (e.g. Ellipsis) is unsupported.
          raise ValueError('Slicing not supported for %r' % k)
      result_shape = [d for d in result_shape if d != -1]
      return StructuredTensor.from_fields(new_fields, result_shape)

    else:
      if not isinstance(key[rank], compat.bytes_or_text_types):
        # TODO(edloper): Also support full slice here?
        raise ValueError('Key for indexing a StructuredTensor must be a string')
      return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:])

  def __repr__(self):
    # Fields are listed in sorted key order; multi-line field reprs are
    # re-indented so they nest under their key.
    fields = sorted(self._fields.items())
    fields = ((k, str(v).replace('\n', '\n ')) for k, v in fields)
    fields = ('"{}": {}'.format(k, v) for k, v in fields)
    dict_repr = ',\n '.join(fields)
    return ('<StructuredTensor(\n'
            ' fields={\n'
            ' %s},\n'
            ' shape=%s)>' % (dict_repr, self.shape))

  #=============================================================================
  # Conversion
  #=============================================================================

  def to_pyval(self):
    """Returns this StructuredTensor as a nested Python dict or list of dicts.

    Converts this `StructuredTensor` to a nested python value:

    * `StructuredTensors` with `rank=0` are converted into a dictionary, with
      an entry for each field.  Field names are used as keys and field values
      are converted to python values.  In particular:

      * Scalar Tensor fields are converted to simple values (such as
        `int` or `float` or `string`)
      * Non-scalar Tensor fields and RaggedTensor fields are converted to
        nested lists of simple values.
      * StructuredTensor fields are converted recursively using `to_pyval`.
813 814 * `StructTensors` with `rank>0` are converted to nested python `list`s, 815 containing one dictionary for each structure (where each structure's 816 dictionary is defined as described above). 817 818 Requires that all fields are Eager tensors. 819 820 >>> StructuredTensor.from_fields( 821 ... {'a': [1, 2, 3]}, [3]).to_pyval() 822 [{'a': 1}, {'a': 2}, {'a': 3}] 823 824 Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`. 825 826 Returns: 827 A nested Python dict or list of dicts. 828 """ 829 if not self._is_eager(): 830 raise ValueError( 831 'StructuredTensor.to_pyval() is only supported in eager mode.') 832 833 # Convert each field value to a nested list. 834 result = {} 835 for (key, value) in self._fields.items(): 836 if isinstance(value, ops.EagerTensor): 837 value = value.numpy() 838 if isinstance(value, np.ndarray): 839 value = value.tolist() 840 elif isinstance(value, ragged_tensor.RaggedTensor): 841 value = value.to_list() 842 elif isinstance(value, StructuredTensor): 843 value = value.to_pyval() 844 # TODO(edloper): Throw an exception if value is an unexpected type. 845 result[key] = value 846 847 # If rank>0, then re-group each value from dict-of-list to list-of-dict. 848 if len(self.shape) > 0: # pylint: disable=g-explicit-length-test 849 if not result: # special-case for StructuredTensors w/ no fields. 850 return _empty_dict_pylist_from_row_partitions(self.row_partitions, 851 self.nrows()) 852 return _pyval_field_major_to_node_major( 853 list(result.keys()), list(result.values()), self.rank) 854 else: 855 return result 856 857 @classmethod 858 def from_pyval(cls, pyval, typespec=None): 859 """Constructs a StructuredTensor from a nested Python structure. 860 861 >>> StructuredTensor.from_pyval( 862 ... 
{'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]}) 863 <StructuredTensor( 864 fields={ 865 "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32), 866 "b": <tf.RaggedTensor [[4, 5], [6, 7]]>}, 867 shape=())> 868 869 Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`. 870 871 Args: 872 pyval: The nested Python structure that should be used to create the new 873 `StructuredTensor`. 874 typespec: A `StructuredTensor.Spec` specifying the expected type for each 875 field. If not specified, then all nested dictionaries are turned into 876 StructuredTensors, and all nested lists are turned into Tensors (if 877 rank<2) or RaggedTensors (if rank>=2). 878 879 Returns: 880 A `StructuredTensor`. 881 """ 882 return cls._from_pyval(pyval, typespec, ()) 883 884 @classmethod 885 def _from_pyval(cls, pyval, typespec, path_so_far): 886 """Helper function for from_pyval. 887 888 889 Args: 890 pyval: The nested Python structure that should be used to create the new 891 `StructuredTensor`. 892 typespec: A `StructuredTensor.Spec` specifying the expected type for each 893 field. If not specified, then all nested dictionaries are turned into 894 StructuredTensors, and all nested lists are turned into Tensors (if 895 rank<2) or RaggedTensors (if rank>=2). 896 path_so_far: the path of fields that led here (for error messages). 897 898 Returns: 899 A `StructuredTensor`. 
900 """ 901 if isinstance(pyval, dict): 902 return cls._from_pydict(pyval, typespec, path_so_far) 903 elif isinstance(pyval, (list, tuple)): 904 keys = set() 905 rank = _pyval_find_struct_keys_and_depth(pyval, keys) 906 if rank is not None: 907 return cls._from_pylist_of_dict(pyval, keys, rank, typespec, 908 path_so_far) 909 else: 910 return cls._from_pylist_of_value(pyval, typespec, path_so_far) 911 else: 912 return cls._from_pyscalar(pyval, typespec, path_so_far) 913 914 @classmethod 915 def _from_pydict(cls, pyval, typespec, path_so_far): 916 """Converts python dictionary `pyval` to a StructuredTensor with rank=0.""" 917 if typespec is None: 918 fields = dict((k, cls._from_pyval(v, None, path_so_far + (k,))) 919 for (k, v) in pyval.items()) 920 else: 921 spec_shape = typespec._shape # pylint: disable=protected-access 922 field_specs = typespec._field_specs # pylint: disable=protected-access 923 if not (isinstance(typespec, StructuredTensor.Spec) and 924 spec_shape.rank == 0 and set(pyval) == set(field_specs)): 925 raise ValueError('Value at %r does not match typespec: %r vs %r' % 926 (path_so_far, pyval, typespec)) 927 fields = dict((k, cls._from_pyval(v, field_specs[k], path_so_far + (k,))) 928 for (k, v) in pyval.items()) 929 return StructuredTensor.from_fields(fields=fields, shape=(), validate=False) 930 931 @classmethod 932 def _from_pylist_of_dict(cls, pyval, keys, rank, typespec, path_so_far): 933 """Converts python list `pyval` to a StructuredTensor with rank>1.""" 934 fields = dict((key, []) for key in keys) 935 for child in pyval: 936 _pyval_update_fields(child, fields, 1) 937 if typespec is None: 938 shape = tensor_shape.TensorShape([None] * rank) 939 for (key, target) in fields.items(): 940 fields[key] = cls._from_pyval(target, None, path_so_far + (key,)) 941 else: 942 field_specs = typespec._fields # pylint: disable=protected-access 943 if ((not isinstance(typespec, StructuredTensor.Spec)) or # pylint: disable=superfluous-parens 944 (set(fields) - 
set(field_specs))): 945 raise ValueError('Value at %r does not match typespec: %r vs %r' % 946 (path_so_far, pyval, typespec)) 947 shape = typespec._shape 948 if shape.rank < rank: 949 raise ValueError('Value at %r does not match typespec (rank mismatch): ' 950 '%r vs %r' % (path_so_far, pyval, typespec)) 951 for (key, spec) in field_specs.items(): 952 fields[key] = cls._from_pyval( 953 fields.get(key, []), spec, path_so_far + (key,)) 954 try: 955 if not fields and typespec is None: 956 # TODO(b/183245576): handle cases where the typespec is known 957 # but the dictionary is empty. 958 return StructuredTensor._from_pylist_of_empty_dict(pyval, rank) 959 return StructuredTensor.from_fields( 960 fields=fields, shape=shape, validate=False) 961 except Exception as exc: 962 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 963 964 @classmethod 965 def _from_pylist_of_empty_dict(cls, pyval, rank): 966 """Converts a pylist of empty dictionaries to StructuredTensors.""" 967 if rank == 0: 968 return StructuredTensor.from_fields(fields={}, shape=(), validate=False) 969 elif rank == 1: 970 nrows = len(pyval) 971 shape = (nrows,) 972 return StructuredTensor.from_fields(fields={}, shape=shape, nrows=nrows) 973 elif rank > 1: 974 ragged_zeros = ragged_factory_ops.constant(_dicts_to_zeros(pyval)) 975 nrows = len(pyval) 976 shape = tensor_shape.TensorShape([len(pyval)] + ([None] * (rank - 1))) 977 return StructuredTensor.from_fields( 978 fields={}, 979 shape=shape, 980 row_partitions=ragged_zeros._nested_row_partitions, # pylint:disable=protected-access 981 nrows=nrows) 982 983 @classmethod 984 def _from_pylist_of_value(cls, pyval, typespec, path_so_far): 985 """Converts python list `pyval` to a Tensor or RaggedTensor with rank>1.""" 986 if typespec is None: 987 try: 988 return ragged_factory_ops.constant(pyval) 989 except Exception as exc: 990 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 991 elif isinstance(typespec, 
tensor_spec.TensorSpec): 992 try: 993 result = constant_op.constant(pyval, typespec.dtype) 994 except Exception as exc: 995 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 996 if not typespec.shape.is_compatible_with(result.shape): 997 raise ValueError('Value at %r does not match typespec: %r vs %r' % 998 (path_so_far, typespec, pyval)) 999 return result 1000 elif isinstance(typespec, ragged_tensor.RaggedTensorSpec): 1001 # pylint: disable=protected-access 1002 try: 1003 return ragged_factory_ops.constant( 1004 pyval, 1005 dtype=typespec._dtype, 1006 ragged_rank=typespec._ragged_rank, 1007 row_splits_dtype=typespec._row_splits_dtype, 1008 inner_shape=typespec._shape[typespec._ragged_rank + 1:]) 1009 except Exception as exc: 1010 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 1011 elif isinstance(typespec, StructuredTensor.Spec): 1012 empty_rank = _pyval_empty_list_depth(pyval) 1013 if empty_rank is None: 1014 raise ValueError('Value at %r does not match typespec: %r vs %r' % 1015 (path_so_far, typespec, pyval)) 1016 else: 1017 return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec, 1018 path_so_far) 1019 else: 1020 raise ValueError('Value at %r does not match typespec: %r vs %r' % 1021 (path_so_far, typespec, pyval)) 1022 1023 @classmethod 1024 def _from_pyscalar(cls, pyval, typespec, path_so_far): 1025 """Converts python scalar value `pyval` to a Tensor.""" 1026 if typespec is None: 1027 try: 1028 return constant_op.constant(pyval) 1029 except Exception as exc: 1030 raise ValueError('Error parsing path %r' % (path_so_far,)) from exc 1031 else: 1032 if not (isinstance(typespec, tensor_spec.TensorSpec) and 1033 typespec.shape.rank == 0): 1034 raise ValueError('Value at %r does not match typespec: %r vs %r' % 1035 (path_so_far, typespec, pyval)) 1036 # TODO(edloper): Check that typespec.shape matches. 
1037 return constant_op.constant(pyval, typespec.dtype) 1038 1039 #============================================================================= 1040 # Transforms 1041 #============================================================================= 1042 1043 # TODO(edloper): Add a 'validate' option here? 1044 # TODO(edloper): Unify nomenclature with RaggedTensor. Should RaggedTensor 1045 # have a partition_outer_dimension method? 1046 def partition_outer_dimension(self, row_partition): 1047 """Partitions the outer dimension of this StructuredTensor. 1048 1049 Returns a new `StructuredTensor` with the same values as `self`, where 1050 the outer dimension is partitioned into two (possibly ragged) dimensions. 1051 Requires that this StructuredTensor have an outer dimension (i.e., 1052 `self.shape.rank > 0`). 1053 1054 >>> st = StructuredTensor.from_pyval( 1055 ... [{'foo': 12}, {'foo': 33}, {'foo': 99}]) 1056 >>> partition = RowPartition.from_row_lengths([2, 0, 1]) 1057 >>> st.partition_outer_dimension(partition) 1058 <StructuredTensor( 1059 fields={ 1060 "foo": <tf.RaggedTensor [[12, 33], [], [99]]>}, 1061 shape=(3, None))> 1062 1063 Args: 1064 row_partition: A `RowPartition`. 1065 1066 Returns: 1067 A `StructuredTensor` with rank `values.rank + 1`. 1068 """ 1069 if not isinstance(row_partition, RowPartition): 1070 raise TypeError('row_partition must be a RowPartition.') 1071 if self.shape.rank == 0: 1072 raise ValueError('Shape %s must have rank at least 1' % self.shape) 1073 return _partition_outer_dimension(self, row_partition) 1074 1075 def merge_dims(self, outer_axis, inner_axis): 1076 """Merges outer_axis...inner_axis into a single dimension. 1077 1078 Returns a copy of this RaggedTensor with the specified range of dimensions 1079 flattened into a single dimension, with elements in row-major order. 1080 1081 >>> st = StructuredTensor.from_pyval( 1082 ... 
[[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]]) 1083 >>> st.merge_dims(0, 1) 1084 <StructuredTensor( 1085 fields={ 1086 "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)}, 1087 shape=(3,))> 1088 1089 Args: 1090 outer_axis: `int`: The first dimension in the range of dimensions to 1091 merge. May be negative (to index from the last dimension). 1092 inner_axis: `int`: The last dimension in the range of dimensions to merge. 1093 May be negative (to index from the last dimension). 1094 1095 Returns: 1096 A copy of this tensor, with the specified dimensions merged into a 1097 single dimension. The shape of the returned tensor will be 1098 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` 1099 is the total number of slices in the merged dimensions. 1100 """ 1101 outer_axis = array_ops.get_positive_axis( 1102 outer_axis, 1103 self.shape.rank, 1104 axis_name='outer_axis', 1105 ndims_name='rank(self)') 1106 inner_axis = array_ops.get_positive_axis( 1107 inner_axis, 1108 self.shape.rank, 1109 axis_name='inner_axis', 1110 ndims_name='rank(self)') 1111 if not outer_axis <= inner_axis: 1112 raise ValueError('Expected outer_axis (%d) to be less than or equal to ' 1113 'inner_axis (%d)' % (outer_axis, inner_axis)) 1114 return _merge_dims(self, outer_axis, inner_axis) 1115 1116 class Spec: 1117 """A spec for StructuredTensor.""" 1118 1119 def __validate__(self): 1120 assert self._ragged_shape is not None 1121 1122 @classmethod 1123 def _from_fields_and_rank(cls, fields, rank): 1124 """Creates a spec of a StructuredTensor with fields and rank.""" 1125 shape = None 1126 for (k, v) in fields.items(): 1127 field_shape_untruncated = _dynamic_ragged_shape_spec_from_spec(v) 1128 if field_shape_untruncated is None: 1129 raise ValueError(f'Cannot convert spec of {k}.') 1130 untruncated_rank = field_shape_untruncated.rank 1131 if (untruncated_rank is not None 1132 and untruncated_rank < rank): 1133 raise ValueError( 1134 f'Rank of field {k} is {untruncated_rank}, 
' 1135 f'but must be at least {rank}.') 1136 field_shape = field_shape_untruncated._truncate(rank) # pylint: disable=protected-access 1137 if shape is None: 1138 shape = field_shape 1139 else: 1140 shape = shape._merge_with(field_shape) 1141 return StructuredTensor.Spec(_ragged_shape=shape, _fields=fields) 1142 1143 @classmethod 1144 def _from_shape( 1145 cls, shape: dynamic_ragged_shape.DynamicRaggedShape 1146 ) -> 'StructuredTensor.Spec': 1147 """Creates the spec of an empty StructuredTensor.""" 1148 return StructuredTensor.Spec(_ragged_shape=shape, _fields={}) 1149 1150 # For backwards compatibility 1151 @property 1152 def _shape(self) -> tensor_shape.TensorShape: 1153 return self._ragged_shape._to_tensor_shape() # pylint: disable=protected-access 1154 1155 # For backwards compatibility 1156 @property 1157 def _field_specs(self) -> Dict[str, type_spec.TypeSpec]: 1158 return self._fields 1159 1160 # For backwards compatibility 1161 @property 1162 def shape(self) -> tensor_shape.TensorShape: 1163 return self._shape 1164 1165 # For backwards compatibility 1166 @property 1167 def rank(self): 1168 return self._ragged_shape.rank 1169 1170 1171# Regular expression used to determine whether a string is a valid field name. 1172# Note: we plan to relax (or possibly eliminate) this in the future; you 1173# should not rely on the fact that some field names are currently disallowed. 1174_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$') 1175 1176#============================================================================= 1177# Helper funtions 1178#============================================================================= 1179# TODO(edloper): Move some of these helpers to row_partition.py? 


def _convert_to_structured_field_value(value):
  """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

  Args:
    value: A field value: a `Tensor`, `RaggedTensor`, `StructuredTensor`, a
      ragged value, another `ExtensionType`, or anything accepted by
      `ops.convert_to_tensor`.

  Returns:
    `value` unchanged if it is already a `Tensor`, `RaggedTensor`,
    `StructuredTensor`, or `ExtensionType`; otherwise the converted value.

  Raises:
    TypeError: If `value` cannot be converted to a tensor.
  """
  if isinstance(value,
                (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
    return value
  elif ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  elif isinstance(value, extension_type.ExtensionType):
    return value
  else:
    try:
      return ops.convert_to_tensor(value)
    except (ValueError, TypeError) as e:
      # `value` is wrapped in a tuple so `%` formatting also works (and shows
      # the whole value) when `value` is itself a tuple.
      raise TypeError('Unexpected type for value in `fields`: %r' %
                      (value,)) from e


def _find_shape_dtype(fields, nrows, row_partitions):
  """Return a consistent dtype for fields, nrows, & row_partitions.

  Args:
    fields: A dict of field values. Only `RaggedTensor` fields (via their
      row_splits dtype) and `StructuredTensor` fields with rank > 0 (via their
      nrows dtype) contribute a dtype.
    nrows: `None`, a python int, or a scalar integer `Tensor`. Only a `Tensor`
      contributes a dtype.
    row_partitions: `None` or a sequence of `RowPartition`s.

  Returns:
    The single dtype shared by all contributing values, or `dtypes.int64` if
    none of them determines a dtype.

  Raises:
    ValueError: If the contributing dtypes are not all equal.
  """
  field_dtypes = dict()
  for (key, value) in fields.items():
    if isinstance(value, ragged_tensor.RaggedTensor):
      field_dtypes[key] = value.row_splits.dtype
    elif isinstance(value, StructuredTensor) and value.rank > 0:
      field_dtypes[key] = value.nrows().dtype

  field_dtype = None
  for value in field_dtypes.values():
    if field_dtype is None:
      field_dtype = value
    elif field_dtype != value:
      raise ValueError('field values have incompatible row_partition dtypes. ' +
                       f'field_dtypes: {field_dtypes}')

  row_partition_dtype = None
  row_partition_dtypes = []
  if row_partitions is not None:
    row_partition_dtypes = [rp.dtype for rp in row_partitions]
  for rp_dtype in row_partition_dtypes:
    if row_partition_dtype is None:
      row_partition_dtype = rp_dtype
    elif row_partition_dtype != rp_dtype:
      raise ValueError('row_partitions have incompatible dtypes with '
                       f'themselves:{row_partition_dtypes}')

  nrows_dtype = None
  if isinstance(nrows, ops.Tensor):
    nrows_dtype = nrows.dtype
  shape_dtypes = {
      d for d in (field_dtype, row_partition_dtype, nrows_dtype)
      if d is not None
  }
  if len(shape_dtypes) > 1:
    raise ValueError('row_partition dtypes are inconsistent: ' +
                     f'field_dtype:{field_dtype} ' +
                     f'row_partition_dtype:{row_partition_dtype} ' +
                     f'nrows_dtype:{nrows_dtype}')
  elif shape_dtypes:
    return shape_dtypes.pop()
  else:
    return dtypes.int64


def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Merges `nrows` with `nrows(value)`.

  Checks that `value` has the expected number of rows (`nrows`), and returns
  `nrows`.  If `validate` is true, then add validation ops that check that
  the `nrows` values match.

  Args:
    nrows: scalar integer Tensor.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for `nrows`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`.

  Raises:
    ValueError: If the static row counts are both known and incompatible.
  """
  static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  if isinstance(value, ops.Tensor):
    value_nrows = array_ops.shape(value, out_type=dtype)[0]
  else:
    value_nrows = value.nrows()
  if nrows is None:
    nrows = value_nrows
  elif (static_value_nrows.value is not None and
        static_nrows.value is not None):
    if not static_value_nrows.is_compatible_with(static_nrows):
      raise ValueError('fields have incompatible nrows')
    nrows = value_nrows  # No need to add an assertion op.
  elif validate:
    nrows = control_flow_ops.with_dependencies([
        check_ops.assert_equal(
            nrows, value_nrows, message='fields have incompatible nrows')
    ], nrows)
  return nrows, static_nrows._merge_with(static_value_nrows)  # pylint: disable=protected-access


def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Merges `row_partitions` with `row_partitions(value)`.

  Args:
    row_partitions: `None`, or a sequence of `rank - 1` `RowPartition`s
      accumulated so far.
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    rank: The number of outer dimensions being merged.
    dtype: The row-partition dtype to use when partitions must be built.
    validate: Passed to `RowPartition._merge_precomputed_encodings`.

  Returns:
    A tuple of `rank - 1` `RowPartition`s.
  """
  if isinstance(value, ops.Tensor):
    value_row_partitions = _row_partitions_for_tensor(value, rank, dtype)

  elif isinstance(value, ragged_tensor.RaggedTensor):
    value_row_partitions = _row_partitions_for_ragged_tensor(value, rank, dtype)

  else:
    assert isinstance(value, StructuredTensor), type(value)
    value_row_partitions = value.row_partitions[:rank - 1]

  assert len(value_row_partitions) == rank - 1
  if row_partitions is None:
    return tuple(value_row_partitions)
  else:
    return tuple([
        p1._merge_precomputed_encodings(p2, validate)  # pylint: disable=protected-access
        for (p1, p2) in zip(row_partitions, value_row_partitions)
    ])


def _row_partitions_for_tensor(value, rank, dtype):
  """Returns the row partitions for a tf.Tensor."""
  shape = array_ops.shape(value, out_type=dtype)
  return _row_partitions_for_uniform_shape(shape, rank)


def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Returns the row partitions for a tf.RaggedTensor.

  Uses the ragged tensor's own partitions for its ragged dimensions and
  uniform partitions built from `flat_values` for any remaining dimensions.
  """
  assert rank > 1
  value_row_partitions = value._nested_row_partitions[:rank - 1]  # pylint: disable=protected-access
  if len(value_row_partitions) < (rank - 1):
    value_row_partitions += _row_partitions_for_tensor(
        value.flat_values, rank - len(value_row_partitions), dtype)
  assert len(value_row_partitions) == rank - 1
  return value_row_partitions


def _row_partitions_for_uniform_shape(shape, rank):
  """Returns row partitions for the given shape Tensor.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for

  Returns:
    A list of (rank-1) `RowPartition`s with uniform row length.
  """
  shape_cumprod = math_ops.cumprod(shape[:rank])
  # pylint: disable=g-complex-comprehension
  return tuple([
      RowPartition.from_uniform_row_length(
          uniform_row_length=shape[i + 1],
          nvals=shape_cumprod[i + 1],
          nrows=shape_cumprod[i]) for i in range(rank - 1)
  ])


def _pyval_field_major_to_node_major(keys, values, depth):
  """Regroup each field (k, v) from dict-of-list to list-of-dict.

  Given a "field-major" encoding of the StructuredTensor (which maps each key to
  a single nested list containing the values for all structs), return a
  corresponding "node-major" encoding, consisting of a nested list of dicts.

  Args:
    keys: The field names (list of string).  Must not be empty.
    values: The field values (list of python values).  Must have the same length
      as `keys`.
    depth: The list depth at which dictionaries should be created.

  Returns:
    A nested list of dict, with depth `depth`.
  """
  assert keys
  if depth == 0:
    return dict(zip(keys, values))
  nvals = len(values[0])
  assert all(nvals == len(values[i]) for i in range(1, len(values)))
  return [
      _pyval_field_major_to_node_major(keys, value_slice, depth - 1)
      for value_slice in zip(*values)
  ]


def _empty_dict_pylist_from_row_partitions(row_partitions, nrows):
  """Returns a python list of empty dicts from the given row partitions.

  Args:
    row_partitions: The row-partitions describing the ragged shape of the
      result.
    nrows: The number of rows in the outermost row-partition.  (Or if
      `len(row_partitions)==0`, then the number of empty dicts to return.)

  Returns:
    A nested python list whose leaves (if any) are empty python dicts.
  """
  if not row_partitions:
    return [{} for _ in range(nrows)]
  else:
    # Fetch the outer splits once; they give both the number of inner values
    # and the slice boundaries for each row.
    splits = row_partitions[0].row_splits()
    values = _empty_dict_pylist_from_row_partitions(row_partitions[1:],
                                                    splits[-1])
    return [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]


def _pyval_find_struct_keys_and_depth(pyval, keys):
  """Finds the keys & depth of nested dictionaries in `pyval`.

  Args:
    pyval: A nested structure of lists, tuples, and dictionaries.
    keys: (output parameter) A set, which will be updated with any keys that are
      found in the nested dictionaries.

  Returns:
    The nesting depth of dictionaries in `pyval`, or `None` if `pyval` does
    not contain any dictionaries.
  Raises:
    ValueError: If dictionaries have inconsistent depth.
  """
  if isinstance(pyval, dict):
    keys.update(pyval.keys())
    return 0
  elif isinstance(pyval, (list, tuple)):
    depth = None
    for child in pyval:
      child_depth = _pyval_find_struct_keys_and_depth(child, keys)
      if child_depth is not None:
        if depth is None:
          depth = child_depth + 1
        elif depth != child_depth + 1:
          raise ValueError('Inconsistent depth of dictionaries')
    return depth
  else:
    return None


def _pyval_update_fields(pyval, fields, depth):
  """Append the field values from `pyval` to `fields`.

  Args:
    pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s)
      should be appended to `fields`.
    fields: A dictionary mapping string keys to field values.  Field values
      extracted from `pyval` are appended to this dictionary's values.
    depth: The depth at which `pyval` should be appended to the field values.

  Raises:
    ValueError: If `pyval` is not a dict, list, or tuple.
  """
  if not isinstance(pyval, (dict, list, tuple)):
    raise ValueError('Expected dict or nested list/tuple of dict')

  for (key, target) in fields.items():
    # Descend to the innermost list at this depth before appending.
    for _ in range(1, depth):
      target = target[-1]
    target.append(pyval[key] if isinstance(pyval, dict) else [])

  if isinstance(pyval, (list, tuple)):
    for child in pyval:
      _pyval_update_fields(child, fields, depth + 1)


def _pyval_empty_list_depth(pyval):
  """Find the max depth for nested empty lists.

  Args:
    pyval: A nested python list.

  Returns:
    The maximum depth of empty lists in `pyval`, or None if `pyval` contains
    anything other than nested empty lists.
  """
  if isinstance(pyval, list):
    if not pyval:
      return 1
    depths = [_pyval_empty_list_depth(v) for v in pyval]
    if any(depth is None for depth in depths):
      return None
    else:
      return max(depths) + 1
  else:
    return None


def _replace_row_partitions(value, new_partitions):
  """Updates `value` to use `new_partitions` as its (outer) row partitions.

  This is used to ensure that all fields in a `StructuredTensor` use identical
  `RowPartition` objects for the shared dimensions.  In particular,
  `StructuredTensor.from_fields` first merges all of the row partitions from
  any fields, and then replaces the outer row partitions of all fields with
  the merged row partitions (using this function).

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A list of row-partitions that should be used by `value`.
      Must be equivalent to `value`'s current row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have been
    replaced by `new_partitions`.
  """
  if isinstance(value, ops.Tensor) or not new_partitions:
    return value

  elif isinstance(value, ragged_tensor.RaggedTensor):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=_replace_row_partitions(value.values, new_partitions[1:]),
        row_partition=new_partitions[0])

  else:
    assert isinstance(value, StructuredTensor)
    new_fields = dict((k, _replace_row_partitions(v, new_partitions))
                      for (k, v) in value._fields.items())
    return StructuredTensor._old_init(  # pylint: disable=protected-access
        fields=new_fields,
        shape=value.shape,
        nrows=value.nrows(),
        row_partitions=tuple(new_partitions) +
        tuple(value.row_partitions[len(new_partitions):]))


def _partition_outer_dimension(value, row_partition):
  """Partitions the outer dimension of `value` using `row_partitions`.

  Examples:

    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
    >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
    <tf.RaggedTensor [[1, 2], [], [3]]>

    >>> struct_value = StructuredTensor.from_pyval(
    ...     [{'x': 1}, {'x': 2}, {'x': 3}])
    >>> _partition_outer_dimension(struct_value, partition)
    <StructuredTensor(
      fields={
        "x": <tf.RaggedTensor [[1, 2], [], [3]]>},
      shape=(3, None))>

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value with the same type as `value`, where
    `result.rank = value.rank + 1`.
  """
  is_ragged = row_partition.uniform_row_length() is None
  if isinstance(value, ops.Tensor) and not is_ragged:
    # A uniform partition of a dense tensor is just a reshape.
    new_shape = array_ops.concat(
        [[row_partition.nrows(),
          row_partition.uniform_row_length()],
         array_ops.shape(value, out_type=row_partition.dtype)[1:]],
        axis=0)
    return array_ops.reshape(value, new_shape)
  elif isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)
  else:
    assert isinstance(value, StructuredTensor)
    nrows = row_partition.static_nrows
    ncols = row_partition.static_uniform_row_length
    shape = tensor_shape.TensorShape([nrows,
                                      ncols]).concatenate(value.shape[1:])
    fields = dict((k, _partition_outer_dimension(v, row_partition))
                  for (k, v) in value._fields.items())
    return StructuredTensor._old_init(  # pylint: disable=protected-access
        fields,
        shape,
        row_partition.nrows(), (row_partition,) + value.row_partitions)


def _merge_dims(value, outer_axis, inner_axis):
  """Merges `outer_axis...inner_axis` of `value` into a single dimension.

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    outer_axis: Non-negative python int: the first dimension to merge.
    inner_axis: Non-negative python int: the last dimension to merge
      (inclusive); must be `>= outer_axis`.

  Returns:
    A value of the same kind as `value` with the dimensions merged. If
    `outer_axis == inner_axis`, `value` is returned unchanged.
  """
  assert outer_axis <= inner_axis
  if outer_axis == inner_axis:
    # Merging a single dimension is a no-op (this also keeps the
    # StructuredTensor path below from being asked to merge nothing).
    return value
  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
  else:
    assert isinstance(value, StructuredTensor)
    fields = dict((k, _merge_dims(v, outer_axis, inner_axis))
                  for (k, v) in value._fields.items())
    ragged_shape = value._ragged_shape._merge_dims(  # pylint: disable=protected-access
        outer_axis, inner_axis)
    return StructuredTensor(fields, ragged_shape)


_structured_tensor_factory_key = object()  # unique private object


def _dynamic_ragged_shape_spec_from_spec(
    spec: Union[dynamic_ragged_shape.DynamicRaggedShape.Spec,
                ragged_tensor.RaggedTensorSpec, StructuredTensor.Spec,
                tensor_spec.TensorSpec]
) -> dynamic_ragged_shape.DynamicRaggedShape.Spec:
  """Returns the `DynamicRaggedShape.Spec` describing the shape of `spec`.

  For a `StructuredTensor.Spec` this is its stored ragged shape; any other
  spec is converted with `DynamicRaggedShape.Spec._from_spec`.
  """
  if isinstance(spec, StructuredTensor.Spec):
    return spec._ragged_shape  # pylint: disable=protected-access
  else:
    return dynamic_ragged_shape.DynamicRaggedShape.Spec._from_spec(spec)  # pylint: disable=protected-access


def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]:
  """FieldName can be given also as string, this normalizes it to a tuple."""
  if isinstance(name, str):
    return (name,)
  if isinstance(name, list):
    return tuple(name)
  assert isinstance(name, tuple)
  return name


def _dicts_to_zeros(pyval):
  """Replaces dictionaries with zeros in a (nested) pylist.

  Args:
    pyval: A dict, or a nested list whose leaves are dicts.

  Returns:
    `pyval` with the same nesting, where each dict has been replaced by 0.
  """
  if isinstance(pyval, dict):
    return 0
  return [_dicts_to_zeros(x) for x in pyval]


def _merge_dims_generic(source, outer, inner):
  """Merges outer_axis...inner_axis into a single dimension.

  If outer == inner, this is a NOOP. If inner < outer, then this fails.
  If inner >= source.shape.rank, then the behavior is undefined.

  Args:
    source: a tensor, ragged tensor, or structured tensor.
    outer: a python int, indicating the first dimension to merge (must be
      nonnegative).
    inner: a python int, indicating the last dimension to merge, inclusive
      (must be nonnegative).

  Returns:
    source with outer_axis...inner_axis merged into a single dimension.

  """
  if isinstance(source, StructuredTensor):
    return source.merge_dims(outer, inner)
  else:
    return ragged_tensor.merge_dims(source, outer, inner)


def _dynamic_ragged_shape_from_tensor(
    field, dtype=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Extension of DynamicRaggedShape.from_tensor to support StructuredTensor.

  Args:
    field: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    dtype: The dtype for the shape (passed to `tf.shape` as `out_type`).

  Returns:
    A `DynamicRaggedShape` for `field`.

  Raises:
    TypeError: If `tf.shape(field)` is neither a `Tensor` nor a
      `DynamicRaggedShape`.
  """
  if isinstance(field, StructuredTensor):
    return field._ragged_shape  # pylint: disable=protected-access
  shape = array_ops.shape_v2(field, out_type=dtype)

  if isinstance(shape, ops.Tensor):
    return dynamic_ragged_shape.DynamicRaggedShape(
        row_partitions=[],
        inner_shape=shape)
  elif isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape):
    return shape
  # TODO(martinz): add a test for the following line.
  raise TypeError(f'Expected shape tf.shape({field}) to return a Tensor or a '
                  f'DynamicRaggedShape. Instead, got: {shape}.')


def _merge_with_optional(
    a: Optional[dynamic_ragged_shape.DynamicRaggedShape],
    b: Optional[dynamic_ragged_shape.DynamicRaggedShape]
    ) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Merges two optional shapes; `None` acts as the identity element."""
  if a is None:
    return b
  if b is None:
    return a
  return a._merge_with(b)  # pylint: disable=protected-access


def _shape_from_fields(
    fields, rank: int,
    dtype: dtypes.DType) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Given fields, rank, and dtype, create a shape.

  Args:
    fields: A dict of field values.
    rank: The number of outer dimensions shared by all fields.
    dtype: The dtype for the shape.

  Returns:
    The merge of the first `rank` dimensions of every field's shape, or
    `None` if `fields` is empty.

  Raises:
    ValueError: If a field's shape cannot be computed or merged.
  """

  field_shape = None
  for (k, field) in fields.items():
    try:
      next_field_shape_raw = _dynamic_ragged_shape_from_tensor(
          field, dtype=dtype)
      next_field_shape = next_field_shape_raw[:rank]
      field_shape = _merge_with_optional(field_shape, next_field_shape)
    except Exception as err:
      raise ValueError(f'Error in shape of {k}') from err

  return field_shape


# pylint:disable=protected-access
def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions):
  """Produce a DynamicRaggedShape for StructuredTensor.

  Args:
    fields: A dict of field values.
    shape: A `TensorShape` with known rank.
    nrows: `None`, a python int, or a scalar integer `Tensor`.
    row_partitions: `None` or a tuple of `RowPartition`s.

  Returns:
    A `DynamicRaggedShape` combining the information from all arguments.

  Raises:
    TypeError: If `shape` has unknown rank.
    ValueError: If the arguments do not determine the shape, or their
      dtypes are inconsistent.
  """
  assert isinstance(fields, dict), fields
  assert isinstance(shape, tensor_shape.TensorShape), shape
  assert nrows is None or isinstance(nrows, ops.Tensor) or isinstance(
      nrows, int), nrows
  assert row_partitions is None or isinstance(row_partitions,
                                              tuple), row_partitions
  rank = shape.rank

  if rank is None:
    raise TypeError("StructuredTensor's shape must have known rank.")

  # TODO(martinz): figure out whether to validate.
  dtype = _find_shape_dtype(fields, nrows, row_partitions)

  if rank == 0:
    # A scalar has an empty inner shape; nothing else needs to be merged, so
    # return before building any other shape.
    return dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
        array_ops.zeros((0,), dtype=dtype))

  result = None
  if shape.is_fully_defined():
    result = dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
        shape.as_list(), dtype=dtype)

  result = _merge_with_optional(result, _shape_from_fields(fields, rank, dtype))
  if rank == 1:
    alt_value = tensor_shape.dimension_value(shape[0])
    if alt_value is not None:
      nrows = alt_value
    if nrows is not None:
      result = _merge_with_optional(
          result,
          dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
              [nrows], dtype=dtype))
    if result is None:
      raise ValueError('Must specify `nrows`, a fully specified `shape`,' +
                       ' or have `fields` if `rank=1`')

    return result

  if row_partitions:
    result = _merge_with_optional(
        result, dynamic_ragged_shape.DynamicRaggedShape.from_row_partitions(
            row_partitions, dtype=dtype))

  if result is None:
    raise ValueError('Must specify row_partitions, a fully specified shape, ' +
                     'or have fields if rank > 1')
  return result


# TODO(martinz): Drop this method or rename.
def StructuredTensorSpec(shape, field_specs):  # pylint:disable=invalid-name
  """Builds a `StructuredTensor.Spec` from a shape and per-field specs.

  Compatibility shim standing in for the old `StructuredTensorSpec` class.

  Args:
    shape: any value accepted by `tensor_shape.as_shape`; it must have a
      known rank.
    field_specs: a `dict` mapping string field names to `TypeSpec`s.

  Returns:
    A `StructuredTensor.Spec` whose ragged shape is `shape` merged with the
    first `rank` dimensions of every field spec's shape.

  Raises:
    TypeError: if `field_specs` is not a dict, has a non-string key, or has a
      non-`TypeSpec` value (keys are all checked before any value), or if the
      shape spec built from `shape` has unknown rank.
    ValueError: if a field spec cannot be converted to a shape spec, or has a
      known rank smaller than the rank of `shape`.
  """
  if not isinstance(field_specs, dict):
    raise TypeError('field_specs must be a dictionary.')
  # Keys are validated in full before values, so a bad key is always the
  # error that gets reported first.
  if not all(isinstance(name, str) for name in field_specs):
    raise TypeError('field_specs must be a dictionary with string keys.')
  if not all(isinstance(spec, type_spec.TypeSpec)
             for spec in field_specs.values()):
    raise TypeError('field_specs must be a dictionary with TypeSpec values.')

  ragged_spec_cls = dynamic_ragged_shape.DynamicRaggedShape.Spec
  merged_shape = ragged_spec_cls._from_tensor_shape(
      tensor_shape.as_shape(shape), 0, dtypes.int32)
  target_rank = merged_shape.rank
  if target_rank is None:
    raise TypeError("StructuredTensor's shape must have known rank.")

  for name, spec in field_specs.items():
    full_shape = _dynamic_ragged_shape_spec_from_spec(spec)
    if full_shape is None:
      raise ValueError(f'Cannot convert spec of {name}.')
    full_rank = full_shape.rank
    # Every field needs at least as many dimensions as the struct itself;
    # a field of unknown rank skips this check.
    if full_rank is not None and full_rank < target_rank:
      raise ValueError(
          f'Rank of field {name} is {full_rank},'
          f' but must be at least {target_rank}.')
    merged_shape = merged_shape._merge_with(full_shape._truncate(target_rank))

  return StructuredTensor.Spec(_ragged_shape=merged_shape, _fields=field_specs)