xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/structured/structured_tensor.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Structured Tensors."""
16
17import re
18from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
19
20import numpy as np
21
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import extension_type
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_spec
28from tensorflow.python.framework import type_spec
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import check_ops
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops.ragged import dynamic_ragged_shape
34from tensorflow.python.ops.ragged import ragged_factory_ops
35from tensorflow.python.ops.ragged import ragged_tensor
36from tensorflow.python.ops.ragged.row_partition import RowPartition
37from tensorflow.python.util import compat
38from tensorflow.python.util import nest
39
40
class StructuredTensor(extension_type.BatchableExtensionType):
  """A multidimensional collection of structures with the same schema.

  A **`StructuredTensor`** is a multi-dimensional collection of ***structures***
  with the same ***schema***, where:

  * A ***schema*** is a collection of fields, each of which has a name and type.
  * A ***structure*** maps each field in the schema to a tensor value (which
    could be a nested StructuredTensor).

  As an important special case, a 1D `StructuredTensor` encodes a 2D table,
  where columns are heterogeneous `Tensor`s, and rows are the aligned elements
  in each of those `Tensor`s.

  Internally, StructuredTensors use a "field-major" encoding: for each leaf
  field, there is a single tensor that stores the value of that field for all
  structures in the `StructuredTensor`.

  ### Examples

  >>> # A scalar StructuredTensor describing a single person.
  >>> s1 = StructuredTensor.from_pyval(
  ...     {"age": 82, "nicknames": ["Bob", "Bobby"]})
  >>> s1.shape
  TensorShape([])
  >>> s1["age"]
  <tf.Tensor: shape=(), dtype=int32, numpy=82>

  >>> # A vector StructuredTensor describing three people.
  >>> s2 = StructuredTensor.from_pyval([
  ...     {"age": 12, "nicknames": ["Josaphine"]},
  ...     {"age": 82, "nicknames": ["Bob", "Bobby"]},
  ...     {"age": 42, "nicknames": ["Elmo"]}])
  >>> s2.shape
  TensorShape([3])
  >>> s2[0]["age"]
  <tf.Tensor: shape=(), dtype=int32, numpy=12>


  ### Field Paths

  A *field path* is a tuple of field names, specifying the path to a nested
  field.  For example, `("shoes", "sizes")` names the `"sizes"` field of the
  nested `StructuredTensor` stored in the `"shoes"` field.  `field_value` and
  `with_updates` accept a field path wherever they accept a single field name.
  """
  # Field name -> field value.  Each value holds that field for *every*
  # structure in this StructuredTensor (the "field-major" encoding described
  # in the class docstring).
  _fields: Mapping[str, Union[ops.Tensor, ragged_tensor.RaggedTensor,
                              'StructuredTensor', extension_type.ExtensionType]]
  # The (possibly ragged) shape of this StructuredTensor.  `rank`, `shape`,
  # `row_partitions` and `nrows()` are all derived from it.
  _ragged_shape: dynamic_ragged_shape.DynamicRaggedShape

  # NOTE(review): presumably the user-facing type name used by the
  # ExtensionType machinery (e.g. in reprs/registration) -- confirm.
  __name__ = 'tf.StructuredTensor'
  #=============================================================================
  # Common Types
  #=============================================================================
  # pylint: disable=invalid-name
  # A field name: either a single string naming a direct field, or a sequence
  # of strings forming a path into nested (embedded) StructuredTensors.
  FieldName = Union[str, Sequence[str]]

  # Each field may contain one of the following types of Tensors.
  FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor']

  # Function that takes a FieldValue as input and returns the transformed
  # FieldValue (used by `with_updates`).
  FieldFn = Callable[[FieldValue], FieldValue]

  # pylint: enable=invalid-name
106
107  #=============================================================================
108  # Constructor & Factory Methods
109  #=============================================================================
110  def __init__(self, fields: Mapping[str, FieldValue],
111               ragged_shape: dynamic_ragged_shape.DynamicRaggedShape):
112    self._fields = fields
113    self._ragged_shape = ragged_shape
114
115  @classmethod
116  def _old_init(cls, fields, shape, nrows, row_partitions, internal=False):
117    """Private constructor -- use factory methods to create StructuredTensors.
118
119    This constructor builds a `StructuredTensor` from the given attributes,
120    performing minimal validation.
121
122    Args:
123      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
124        `StructuredTensor`.  (This dict is not copied, so the caller must ensure
125        that it does not get mutated via leaked references.)
126      shape: `tf.TensorShape` with statically known rank.
127      nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`.
128      row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`.
129      internal: ignored argument.
130    Returns:
131      a StructuredTensor.
132    """
133    assert isinstance(fields, dict), fields
134    assert isinstance(shape, tensor_shape.TensorShape), shape
135    assert nrows is None or isinstance(nrows, ops.Tensor), nrows
136    assert row_partitions is None or isinstance(row_partitions,
137                                                tuple), row_partitions
138    return StructuredTensor(
139        fields=fields,
140        ragged_shape=_dynamic_ragged_shape_init(fields, shape, nrows,
141                                                row_partitions))
142
143  @classmethod
144  def from_shape(
145      cls, ragged_shape: dynamic_ragged_shape.DynamicRaggedShape
146  ) -> 'StructuredTensor':
147    """Creates a `StructuredTensor` with no fields and ragged_shape.
148
149    Args:
150      ragged_shape: the shape of the structured tensor.
151
152    Returns:
153      a StructuredTensor with no fields and ragged_shape.
154    """
155    return StructuredTensor(fields={}, ragged_shape=ragged_shape)
156
157  @classmethod
158  def from_fields(cls,
159                  fields,
160                  shape=(),
161                  nrows=None,
162                  row_partitions=None,
163                  validate=False):
164    """Creates a `StructuredTensor` from a dictionary of fields.
165
166    Args:
167      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
168        `StructuredTensor`, providing the values for individual fields in each
169        structure.  If `shape.rank > 0`, then every tensor in `fields` must have
170        the same shape in the first `shape.rank` dimensions; and that shape must
171        be compatible with `shape`; and `result[i1...iN][key] =
172        fields[key][i1...iN]` (where `N==shape.rank`).
173      shape: A `TensorShape`: static information about the shape of the
174        `StructuredTensor`.  Must have a known `rank`.  Defaults to scalar shape
175        (i.e. `rank=0`).
176      nrows: scalar integer tensor containing the number of rows in this
177        `StructuredTensor`.  Should only be specified if `shape.rank > 0`.
178        Default value is inferred from the `fields` values.  If `fields` is
179        empty, then this must be specified.
180      row_partitions: A list of `RowPartition`s describing the (possibly ragged)
181        shape of this `StructuredTensor`.  Should only be specified if
182        `shape.rank > 1`.  Default value is inferred from the `fields` values.
183        If `fields` is empty, then this must be specified.
184      validate: If true, then add runtime validation ops that check that the
185        field values all have compatible shapes in the outer `shape.rank`
186        dimensions.
187
188    Returns:
189      A `StructuredTensor`.
190
191    Examples:
192
193      >>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]})
194      <StructuredTensor(
195        fields={
196          "x": tf.Tensor(1, shape=(), dtype=int32),
197          "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
198        shape=())>
199
200      >>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]},
201      ...                              shape=[2])
202      <StructuredTensor(
203        fields={
204          "bar": tf.Tensor([3 4], shape=(2,), dtype=int32),
205          "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)},
206        shape=(2,))>
207    """
208    shape = tensor_shape.as_shape(shape)
209    rank = shape.rank
210    if rank is None:
211      raise ValueError("StructuredTensor's shape must have known rank.")
212    if not isinstance(fields, dict):
213      raise TypeError('fields must be a dictionary, got %s' %
214                      type(fields).__name__)
215    if rank < 2 and row_partitions:
216      raise ValueError('row_partitions must be None or [] if shape.rank<2')
217    if rank == 0 and nrows is not None:
218      raise ValueError('nrows must be None if shape.rank==0')
219    if row_partitions is not None:
220      row_partitions = tuple(row_partitions)
221      if len(row_partitions) != max(0, rank - 1):
222        raise ValueError('len(row_partitions) must be shape.rank-1')
223    elif rank < 2:
224      row_partitions = ()
225
226    fields = dict(fields)  # Make a private copy.
227    with ops.name_scope(None, 'StructuredTensor', fields.values()):
228      # TODO(martinz): Make this have better errors.
229      shape = _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions)
230
231      # TODO(martinz): This may not need to be done if all fields are dense.
232      if shape.rank > 1:
233        shape = shape._with_num_row_partitions(shape.rank - 1)
234
235      # Validate keys and convert field values to tensors.
236      for key, value in fields.items():
237        if not isinstance(key, str):
238
239          raise TypeError(
240              f'Unexpected type for key in `fields`: {key}')
241        if not _FIELD_NAME_RE.match(key):
242          raise ValueError('Field name %r is not currently allowed.' % key)
243        fields[key] = _convert_to_structured_field_value(value)
244
245        fields = dict([(k, _replace_row_partitions(v, row_partitions))
246                       for (k, v) in fields.items()])
247      return cls(fields=fields, ragged_shape=shape)
248
249  @classmethod
250  def from_fields_and_rank(cls, fields, rank, validate=False):
251    """Creates a `StructuredTensor` from a nonempty dictionary of fields.
252
253    Args:
254      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
255        `StructuredTensor`, providing the values for individual fields in each
256        structure.  If `rank > 0`, then every tensor in `fields` must have
257        the same shape in the first `rank` dimensions. Cannot be empty.
258      rank: The rank of the resulting structured tensor.
259      validate: If true, then add runtime validation ops that check that the
260        field values all have compatible shapes in the outer `rank`
261        dimensions.
262
263    Returns:
264      A `StructuredTensor`.
265    Examples:
266      >>> StructuredTensor.from_fields_and_rank({'x': 1, 'y': [1, 2, 3]}, 0)
267      <StructuredTensor(
268        fields={
269          "x": tf.Tensor(1, shape=(), dtype=int32),
270          "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
271        shape=())>
272      >>> StructuredTensor.from_fields_and_rank({'foo': [1, 2], 'bar': [3, 4]},
273      ...                              1)
274      <StructuredTensor(
275        fields={
276          "bar": tf.Tensor([3 4], shape=(2,), dtype=int32),
277          "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)},
278        shape=(2,))>
279    """
280    if not fields:
281      raise ValueError('Must provide at least one field')
282    if not isinstance(rank, int):
283      raise ValueError('rank must be an integer')
284    if rank < 0:
285      raise ValueError('rank must be nonnegative')
286    fields = {
287        k: _convert_to_structured_field_value(v) for (k, v) in fields.items()
288    }
289    dtype = _find_shape_dtype(fields, None, None)
290
291    shape = _shape_from_fields(fields, rank, dtype)
292    if rank > 1:
293      shape = shape._with_num_row_partitions(rank - 1)
294    new_rp = shape._row_partitions  # pylint: disable=protected-access
295    fields = {
296        k: _replace_row_partitions(v, new_rp) for (k, v) in fields.items()
297    }
298    return StructuredTensor(fields=fields, ragged_shape=shape)
299
300  def with_updates(
301      self,
302      updates: Dict[FieldName, Union[FieldValue, FieldFn, None]],
303      validate: bool = False
304  ) -> 'StructuredTensor':
305    """Creates a new `StructuredTensor` with the updated fields.
306
307    If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being
308    updated and `v` the new value, then:
309
310    ```
311    result[k] = v              # If (k, v) is in updates and v is a FieldValue
312    result[k] = f(self[k])     # If (k, f) is in updates and f is a FieldFn
313    result[k] = self[k]        # If k is in self.field_names but not in updates
314    ```
315
316    If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each
317    FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is,
318    prefixed with the same shape as the `StructuredTensor`. Then the resulting
319    `StructuredTensor` will have:
320
321    ```
322    result[i1...iN][k] = v[i1...iN]                        # (k, v) in updates
323    result[i1...iN][k] = f(self.field_value(k))[i1...iN]   # (k, f) in updates
324    result[i1...iN][k] = self[i1...iN][k]                  # k not in updates
325    ```
326
327    Note that `result.shape` is always equal to `self.shape` (but the shapes
328    of nested StructuredTensors may be changed if they are updated with new
329    values).
330
331    Args:
332      updates: A dictionary mapping `FieldName` to either a `FieldValue` to be
333        used to update, or a `FieldFn` that will transform the value for the
334        given `FieldName`. `FieldName` can be a string for a direct field, or a
335        sequence of strings to refer to a nested sub-field. `FieldFn` is a
336        function that takes a `FieldValue` as input and should return a
337        `FieldValue`. All other fields are copied over to the new
338        `StructuredTensor`. New `FieldName` can be given (to add new fields),
339        but only to existing `StructuredTensor`, it won't automatically create
340        new nested structures -- but one can create a whole `StructureTensor`
341        sub-structure and set that into an existing structure. If the new value
342        is set to `None`, it is removed.
343      validate: If true, then add runtime validation ops that check that the
344        field values all have compatible shapes in the outer `shape.rank`
345        dimensions.
346
347    Returns:
348      A `StructuredTensor`.
349
350    Raises:
351      `ValueError`: If the any of the `FieldName` keys points to non-existent
352        sub-structures, if parent and child nodes are updated, if shapes
353        change, if a delete update is given for a non-existant field, or if a
354        `FieldFn` transforming function is given for a `FieldName` that doesn't
355        yet exist.
356
357    Examples:
358
359    >>> shoes_us = StructuredTensor.from_pyval([
360    ...    {"age": 12, "nicknames": ["Josaphine"],
361    ...       "shoes": {"sizes": [8.0, 7.5, 7.5]}},
362    ...    {"age": 82, "nicknames": ["Bob", "Bobby"],
363    ...        "shoes": {"sizes": [11.0, 11.5, 12.0]}},
364    ...    {"age": 42, "nicknames": ["Elmo"],
365    ...        "shoes": {"sizes": [9.0, 9.5, 10.0]}}])
366    >>> def us_to_europe(t):
367    ...   return tf.round(t * 2.54 + 17.0)  # Rough approximation.
368    >>> shoe_sizes_key = ("shoes", "sizes")
369    >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe})
370    >>> shoes_eu.field_value(shoe_sizes_key)
371    <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0],
372    [40.0, 41.0, 42.0]]>
373    """
374    updates_items = [(_normalize_field_name_to_tuple(name), value)
375                     for name, value in updates.items()]
376
377    # Sort by keys and check for updates of both parent and child nodes.
378    updates_items = sorted(updates_items)
379    for i in range(1, len(updates_items)):
380      # Parent of a node would precede node in the sorted order.
381      name = updates_items[i][0]  # item[0] is the name, item[1] is the value.
382      prev_name = updates_items[i - 1][0]
383      if name[:len(prev_name)] == prev_name:
384        raise ValueError(
385            '`StructuredTensor.with_updates` does not allow both parent and '
386            'child nodes to be updated: parent={}, child={}. If needed you can '
387            'update child nodes in the parent update value.'.format(
388                prev_name, name))
389    return self._with_updates_impl((), updates_items, validate)
390
391  def _with_updates_impl(
392      self,
393      error_prefix: Tuple[str],
394      updates: List[Tuple[FieldName, Union[FieldValue, FieldFn]]],
395      validate: bool) -> 'StructuredTensor':
396    """Recursive part of `with_updates` implementation."""
397    # Get current fields.
398    new_fields = dict(self._fields)
399
400    # Convert field name to string with full path for error messages.
401    def name_fullpath(name: Sequence[str]) -> str:
402      return str(error_prefix + (name,))
403
404    # Apply value if a function or the value itself.
405    def apply_value(name: str, value: Union['FieldValue',
406                                            'FieldFn']) -> 'FieldValue':
407      if callable(value):
408        # `value` is actually a transforming function.
409        if name not in new_fields:
410          raise ValueError(
411              '`StructuredTensor.with_updates` cannot update the field {} '
412              'because a transforming function was given, but that field '
413              'does not already exist.'.format(name_fullpath(name)))
414        value = value(new_fields[name])
415      return value
416
417    # Merge updates.
418    for name, value in updates:
419      if not name or not name[0]:
420        raise ValueError(
421            '`StructuredTensor.with_updates` does not allow empty names '
422            '{}.'.format(name_fullpath(name)))
423
424      if len(name) == 1:
425        name = name[0]
426        if value is None:
427          if name not in new_fields:
428            raise ValueError(
429                '`StructuredTensor.with_updates` cannot delete field '
430                '{} because it is not present.'.format(name_fullpath(name)))
431          new_fields.pop(name)
432        else:
433          new_fields[name] = apply_value(name, value)
434      else:
435        # Recursive
436        prefix = name[0]
437        suffix = name[1:]
438        if prefix not in new_fields:
439          raise ValueError(
440              '`StructuredTensor.with_updates` cannot create new sub-field '
441              '{} if parent field {} is not set.'.format(
442                  error_prefix + tuple(name), name_fullpath(prefix)))
443        current_value = new_fields[prefix]
444        if not isinstance(current_value, StructuredTensor):
445          raise ValueError(
446              '`StructuredTensor.with_updates` cannot create new sub-field '
447              '{} if parent structure {} is not a `StructuredTensor` that '
448              'can contain sub-structures -- it is a `{}`.'.format(
449                  error_prefix + tuple(name), name_fullpath(prefix),
450                  type(current_value)))
451        one_update = [(suffix, value)]
452
453        # Accessing protected member in recursion.
454        # FutureWork: optimize by aggregating the recursions, instead of
455        #   calling one at a time.
456        # pylint: disable=protected-access
457        value = current_value._with_updates_impl(error_prefix + (prefix,),
458                                                 one_update, validate)
459        # pylint: enable=protected-access
460        new_fields[prefix] = value
461
462    # TODO(edloper): When validate=True, only validate the modified fields.
463    try:
464      return StructuredTensor.from_fields(
465          new_fields,
466          shape=self.shape,
467          row_partitions=self.row_partitions,
468          nrows=self.nrows(),
469          validate=validate)
470
471    except ValueError as e:
472      msg = '`StructuredTensor.with_updates` failed'
473      if error_prefix:
474        msg = '{} for field {}'.format(msg, error_prefix)
475      raise ValueError(msg) from e
476
477  def _promote_helper(self, source_path, new_parent_path):
478    """Creates a promoted field without adding it to the structure.
479
480    Args:
481      source_path: the source path in the structured tensor.
482      new_parent_path: the new parent path. Must be a prefix of source_path.
483
484    Returns:
485      a composite tensor of source_path promoted.
486    Raises:
487      ValueError: if the shape of the field is unknown and the right strategy
488      cannot be determined.
489    """
490    current_field = self.field_value(source_path)
491    new_parent_rank = self.field_value(new_parent_path).rank
492    parent_rank = self.field_value(source_path[:-1]).rank
493    if new_parent_rank == parent_rank:
494      return current_field
495    current_field_rank = current_field.shape.rank
496    if current_field_rank is None:
497      raise ValueError('Cannot determine if dimensions should be merged.')
498    inner_dim = min(parent_rank, current_field_rank - 1)
499    if inner_dim <= new_parent_rank:
500      return current_field
501    return _merge_dims_generic(current_field, new_parent_rank, inner_dim)
502
503  def promote(self, source_path, new_name):
504    """Promotes a field, merging dimensions between grandparent and parent.
505
506    >>> d = [
507    ...  {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]},
508    ...  {'docs': [{'tokens':[7]}]}]
509    >>> st = StructuredTensor.from_pyval(d)
510    >>> st2 =st.promote(('docs','tokens'), 'docs_tokens')
511    >>> st2[0]['docs_tokens']
512    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
513    >>> st2[1]['docs_tokens']
514    <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)>
515
516    Args:
517      source_path: the path of the field or substructure to promote; must have
518        length at least 2.
519      new_name: the name of the new field (must be a string).
520
521    Returns:
522      a modified structured tensor with the new field as a child of the
523      grandparent of the source_path.
524
525    Raises:
526      ValueError: if source_path is not a list or a tuple or has a length
527        less than two, or new_name is not a string, or the rank
528        of source_path is unknown and it is needed.
529    """
530    if not isinstance(new_name, str):
531      raise ValueError('new_name is not a string')
532    if not isinstance(source_path, (list, tuple)):
533      raise ValueError('source_path must be a list or tuple')
534
535    if len(source_path) < 2:
536      raise ValueError('source_path must have length at least two')
537
538    grandparent_path = source_path[:-2]
539    new_field = self._promote_helper(source_path, grandparent_path)
540    new_path = grandparent_path + (new_name,)
541    return self.with_updates({new_path: new_field})
542
543  #=============================================================================
544  # Properties
545  #=============================================================================
546
547  @property
548  def rank(self):
549    """The rank of this StructuredTensor.  Guaranteed not to be `None`."""
550    return self._ragged_shape.rank
551
552  @property
553  def shape(self):
554    """The static shape of this StructuredTensor.
555
556    The returned `TensorShape` is guaranteed to have a known rank, but the
557    individual dimension sizes may be unknown.
558
559    Returns:
560      `tf.TensorShape`
561    """
562    return self._ragged_shape._to_tensor_shape()  # pylint: disable=protected-access
563
564  # TODO(martinz): for backwards compatibility
565  @property
566  def _row_partitions(self):
567    """Deprecated form of row_partitions."""
568    return self.row_partitions
569
570  # TODO(edloper): Make this a func instead of a property?  Or make nrows
571  # a property instead of a func?  Seems like these should be consistent.
572  @property
573  def row_partitions(self):
574    """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.
575
576    When `self.rank <= 1`, this tuple will be empty.
577
578    When `self.rank > 1`, these `RowPartitions` define the shape of the
579    `StructuredTensor` by describing how a flat (1D) list of structures can be
580    repeatedly partitioned to form a higher-dimensional object.  In particular,
581    the flat list is first partitioned into sublists using `row_partitions[-1]`,
582    and then those sublists are further partitioned using `row_partitions[-2]`,
583    etc.  The following examples show the row partitions used to describe
584    several different `StructuredTensor`, each of which contains 8 copies of
585    the same structure (`x`):
586
587    >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']}       # shape = [] (scalar)
588
589    >>> s1 = [[x, x, x, x], [x, x, x, x]]              # shape = [2, 4]
590    >>> StructuredTensor.from_pyval(s1).row_partitions
591    (tf.RowPartition(row_splits=[0 4 8]),)
592
593    >>> s2 = [[x, x], [x, x], [x, x], [x, x]]          # shape = [4, 2]
594    >>> StructuredTensor.from_pyval(s2).row_partitions
595    (tf.RowPartition(row_splits=[0 2 4 6 8]),)
596
597    >>> s3 = [[x, x, x], [], [x, x, x, x], [x]]        # shape = [2, None]
598    >>> StructuredTensor.from_pyval(s3).row_partitions
599    (tf.RowPartition(row_splits=[0 3 3 7 8]),)
600
601    >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]]      # shape = [2, 2, 2]
602    >>> StructuredTensor.from_pyval(s4).row_partitions
603    (tf.RowPartition(row_splits=[0 2 4]),
604     tf.RowPartition(row_splits=[0 2 4 6 8]))
605
606
607    >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]]  # shape = [3, None, None]
608    >>> StructuredTensor.from_pyval(s5).row_partitions
609    (tf.RowPartition(row_splits=[0 2 3 5]),
610     tf.RowPartition(row_splits=[0 2 3 5 7 8]))
611
612    Note that shapes for nested fields (such as `x['b']` in the above example)
613    are not considered part of the shape of a `StructuredTensor`, and are not
614    included in `row_partitions`.
615
616    If this `StructuredTensor` has a ragged shape (i.e., if any of the
617    `row_partitions` is not uniform in size), then all fields will be encoded
618    as either `RaggedTensor`s or `StructuredTensor`s with these `RowPartition`s
619    used to define their outermost `self.rank` dimensions.
620
621    Returns:
622      A `tuple` of `RowPartition` objects with length `self.rank - 1`
623      (or `0` if `self.rank < 2`)
624
625    """
626    if self.rank < 2:
627      return ()
628    return self._ragged_shape._as_row_partitions()  # pylint:disable=protected-access
629
630  def nrows(self):
631    """The number of rows in this StructuredTensor (if rank>0).
632
633    This means the length of the outer-most dimension of the StructuredTensor.
634
635    Notice that if `self.rank > 1`, then this equals the number of rows
636    of the first row partition. That is,
637    `self.nrows() == self.row_partitions[0].nrows()`.
638
639    Otherwise `self.nrows()` will be the first dimension of the field values.
640
641    Returns:
642      A scalar integer `Tensor` (or `None` if `self.rank == 0`).
643    """
644    if self.rank == 0:
645      return None
646    return self._ragged_shape[0]
647
648  def _is_eager(self):
649    """True if all fields are composed of eager tensors."""
650    tensors = nest.flatten(self, expand_composites=True)
651    return all(isinstance(t, ops.EagerTensor) for t in tensors)
652
653  #=============================================================================
654  # Encoding
655  #=============================================================================
656
657  def field_names(self):
658    """Returns the string field names for this `StructuredTensor`."""
659    return tuple(self._fields.keys())
660
661  def field_value(self, field_name):
662    """Returns the tensor value for the specified field or path.
663
664    If `field_name` is a `string`, then it names a field directly owned by this
665    `StructuredTensor`.  If this `StructuredTensor` has shape `[D1...DN]`, then
666    the returned tensor will have shape `[D1...DN, V1...VM]`, where the slice
667    `result[d1...dN]` contains the field value for the structure at
668    `self[d1...dN]`.
669
670    If `field_name` is a `tuple` of `string`, then it specifies a path to a
671    field owned by nested `StructuredTensor`.  In particular,
672    `struct.field_value((f1, f2, ..., fN))` is equivalent to
673    `struct.field_value(f1).field_value(f2)....field_value(fN)`
674
675    Args:
676      field_name: `string` or `tuple` of `string`: The field whose values should
677        be returned.
678
679    Returns:
680      `Tensor`, `StructuredTensor`, or `RaggedTensor`.
681
682    Raises:
683      KeyError: If the given field_name is not found.
684    """
685    if isinstance(field_name, (list, tuple)):
686      value = self
687      for f in field_name:
688        if not isinstance(value, StructuredTensor):
689          raise KeyError('Field path {} not found in {}'.format(
690              field_name, self))
691        value = value.field_value(f)
692      return value
693    return self._fields[field_name]
694
695  #=============================================================================
696  # Operators
697  #=============================================================================
698
699  # TODO(edloper): Add support for ellipsis and/or newaxis?
700  def __getitem__(self, key):
701    """Returns the specified piece of this StructuredTensor.
702
703    * If `struct_tensor` is scalar (i.e., a single structure), then
704      `struct_tensor[f]` returns the value of field `f` (where `f` must be a
705      string).
706
707    * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional
708      tensor of structures), `struct_tensor[i]` selects an element or slice of
709      the tensor using standard Python semantics (e.g., negative values index
710      from the end).  `i` may have any of the following types:
711
712      * `int` constant
713      * `string` constant
714      * scalar integer `Tensor`
715      * `slice` containing integer constants and/or scalar integer
716        `Tensor`s
717
718    #### Multidimensional indexing
719
720    `StructuredTensor` supports multidimensional indexing.  I.e., `key` may be a
721    `tuple` of values, indexing or slicing multiple dimensions at once.  For
722    example, if `people` is a vector of structures, each of which has a vector-
723    valued `names` field, then `people[3, 'names', 0]` is equivalent to
724    `people[3]['names'][0]`; and `people[:, 'names', :]` will return a (possibly
725    ragged) matrix of names, with shape `[num_people, num_names_per_person]`.
726
727    Args:
728      key: Indicates which piece of the StructuredTensor to return.
729
730    Returns:
731      A `Tensor`, `StructuredTensor`, or `RaggedTensor`.
732    """
733    if isinstance(key, list):
734      key = tuple(key)
735    elif not isinstance(key, tuple):
736      key = (key,)
737    if not key:
738      return self
739
740    if self.rank == 0:
741      return self._scalar_getitem(key)
742    else:
743      return self._tensor_getitem(key)
744
745  def _scalar_getitem(self, key):
746    if (isinstance(key[0], slice) and key[0].start is None and
747        key[0].stop is None and key[0].step is None):
748      fields = dict((field_name, field_value.__getitem__(key[1:]))
749                    for (field_name, field_value) in self._fields.items())
750      return StructuredTensor.from_fields(fields, self.shape)
751
752    elif not isinstance(key[0], compat.bytes_or_text_types):
753      raise ValueError('Key for indexing a StructuredTensor must be a '
754                       "string or a full slice (':')")
755
756    return self._fields[key[0]].__getitem__(key[1:])
757
758  def _tensor_getitem(self, key):
759    rank = self.rank
760    if len(key) <= rank:
761      new_fields = dict((field_name, field_value.__getitem__(key))
762                        for (field_name, field_value) in self._fields.items())
763      result_shape = self.shape.as_list()
764      for d, k in enumerate(key):
765        if isinstance(k, slice):
766          if not (k.start is None and k.stop is None and k.step is None):
767            # TODO(edloper): Better static shape analysis here.
768            result_shape[d] = None
769        elif isinstance(k, (int, ops.Tensor)):
770          result_shape[d] = -1  # mark for deletion
771        elif k is None:
772          raise ValueError('Slicing not supported for tf.newaxis')
773        else:
774          # Ellipsis, tf.newaxis:
775          raise ValueError('Slicing not supported for %r' % k)
776      result_shape = [d for d in result_shape if d != -1]
777      return StructuredTensor.from_fields(new_fields, result_shape)
778
779    else:
780      if not isinstance(key[rank], compat.bytes_or_text_types):
781        # TODO(edloper): Also support full slice here?
782        raise ValueError('Key for indexing a StructuredTensor must be a string')
783      return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:])
784
785  def __repr__(self):
786    fields = sorted(self._fields.items())
787    fields = ((k, str(v).replace('\n', '\n            ')) for k, v in fields)
788    fields = ('"{}": {}'.format(k, v) for k, v in fields)
789    dict_repr = ',\n        '.join(fields)
790    return ('<StructuredTensor(\n'
791            '    fields={\n'
792            '        %s},\n'
793            '    shape=%s)>' % (dict_repr, self.shape))
794
795  #=============================================================================
796  # Conversion
797  #=============================================================================
798
799  def to_pyval(self):
800    """Returns this StructuredTensor as a nested Python dict or list of dicts.
801
802    Converts this `StructuredTensor` to a nested python value:
803
804    * `StructTensors` with `rank=0` are converted into a dictionary, with an
805      entry for each field.  Field names are used as keys and field values are
806      converted to python values.  In particular:
807
808      * Scalar Tensor fields are converted to simple values (such as
809        `int` or `float` or `string`)
810      * Non-scalar Tensor fields and RaggedTensor fields are converted to
811        nested lists of simple values.
812      * StructuredTensor fields are converted recursively using `to_pyval`.
813
814    * `StructTensors` with `rank>0` are converted to nested python `list`s,
815      containing one dictionary for each structure (where each structure's
816      dictionary is defined as described above).
817
818    Requires that all fields are Eager tensors.
819
820    >>> StructuredTensor.from_fields(
821    ...     {'a': [1, 2, 3]}, [3]).to_pyval()
822    [{'a': 1}, {'a': 2}, {'a': 3}]
823
824    Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
825
826    Returns:
827      A nested Python dict or list of dicts.
828    """
829    if not self._is_eager():
830      raise ValueError(
831          'StructuredTensor.to_pyval() is only supported in eager mode.')
832
833    # Convert each field value to a nested list.
834    result = {}
835    for (key, value) in self._fields.items():
836      if isinstance(value, ops.EagerTensor):
837        value = value.numpy()
838      if isinstance(value, np.ndarray):
839        value = value.tolist()
840      elif isinstance(value, ragged_tensor.RaggedTensor):
841        value = value.to_list()
842      elif isinstance(value, StructuredTensor):
843        value = value.to_pyval()
844      # TODO(edloper): Throw an exception if value is an unexpected type.
845      result[key] = value
846
847    # If rank>0, then re-group each value from dict-of-list to list-of-dict.
848    if len(self.shape) > 0:  # pylint: disable=g-explicit-length-test
849      if not result:  # special-case for StructuredTensors w/ no fields.
850        return _empty_dict_pylist_from_row_partitions(self.row_partitions,
851                                                      self.nrows())
852      return _pyval_field_major_to_node_major(
853          list(result.keys()), list(result.values()), self.rank)
854    else:
855      return result
856
857  @classmethod
858  def from_pyval(cls, pyval, typespec=None):
859    """Constructs a StructuredTensor from a nested Python structure.
860
861    >>> StructuredTensor.from_pyval(
862    ...     {'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]})
863    <StructuredTensor(
864        fields={
865          "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32),
866          "b": <tf.RaggedTensor [[4, 5], [6, 7]]>},
867        shape=())>
868
869    Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
870
871    Args:
872      pyval: The nested Python structure that should be used to create the new
873        `StructuredTensor`.
874      typespec: A `StructuredTensor.Spec` specifying the expected type for each
875        field. If not specified, then all nested dictionaries are turned into
876        StructuredTensors, and all nested lists are turned into Tensors (if
877        rank<2) or RaggedTensors (if rank>=2).
878
879    Returns:
880      A `StructuredTensor`.
881    """
882    return cls._from_pyval(pyval, typespec, ())
883
884  @classmethod
885  def _from_pyval(cls, pyval, typespec, path_so_far):
886    """Helper function for from_pyval.
887
888
889    Args:
890      pyval: The nested Python structure that should be used to create the new
891        `StructuredTensor`.
892      typespec: A `StructuredTensor.Spec` specifying the expected type for each
893        field. If not specified, then all nested dictionaries are turned into
894        StructuredTensors, and all nested lists are turned into Tensors (if
895        rank<2) or RaggedTensors (if rank>=2).
896      path_so_far: the path of fields that led here (for error messages).
897
898    Returns:
899      A `StructuredTensor`.
900    """
901    if isinstance(pyval, dict):
902      return cls._from_pydict(pyval, typespec, path_so_far)
903    elif isinstance(pyval, (list, tuple)):
904      keys = set()
905      rank = _pyval_find_struct_keys_and_depth(pyval, keys)
906      if rank is not None:
907        return cls._from_pylist_of_dict(pyval, keys, rank, typespec,
908                                        path_so_far)
909      else:
910        return cls._from_pylist_of_value(pyval, typespec, path_so_far)
911    else:
912      return cls._from_pyscalar(pyval, typespec, path_so_far)
913
914  @classmethod
915  def _from_pydict(cls, pyval, typespec, path_so_far):
916    """Converts python dictionary `pyval` to a StructuredTensor with rank=0."""
917    if typespec is None:
918      fields = dict((k, cls._from_pyval(v, None, path_so_far + (k,)))
919                    for (k, v) in pyval.items())
920    else:
921      spec_shape = typespec._shape  # pylint: disable=protected-access
922      field_specs = typespec._field_specs  # pylint: disable=protected-access
923      if not (isinstance(typespec, StructuredTensor.Spec) and
924              spec_shape.rank == 0 and set(pyval) == set(field_specs)):
925        raise ValueError('Value at %r does not match typespec: %r vs %r' %
926                         (path_so_far, pyval, typespec))
927      fields = dict((k, cls._from_pyval(v, field_specs[k], path_so_far + (k,)))
928                    for (k, v) in pyval.items())
929    return StructuredTensor.from_fields(fields=fields, shape=(), validate=False)
930
931  @classmethod
932  def _from_pylist_of_dict(cls, pyval, keys, rank, typespec, path_so_far):
933    """Converts python list `pyval` to a StructuredTensor with rank>1."""
934    fields = dict((key, []) for key in keys)
935    for child in pyval:
936      _pyval_update_fields(child, fields, 1)
937    if typespec is None:
938      shape = tensor_shape.TensorShape([None] * rank)
939      for (key, target) in fields.items():
940        fields[key] = cls._from_pyval(target, None, path_so_far + (key,))
941    else:
942      field_specs = typespec._fields  # pylint: disable=protected-access
943      if ((not isinstance(typespec, StructuredTensor.Spec)) or  # pylint: disable=superfluous-parens
944          (set(fields) - set(field_specs))):
945        raise ValueError('Value at %r does not match typespec: %r vs %r' %
946                         (path_so_far, pyval, typespec))
947      shape = typespec._shape
948      if shape.rank < rank:
949        raise ValueError('Value at %r does not match typespec (rank mismatch): '
950                         '%r vs %r' % (path_so_far, pyval, typespec))
951      for (key, spec) in field_specs.items():
952        fields[key] = cls._from_pyval(
953            fields.get(key, []), spec, path_so_far + (key,))
954    try:
955      if not fields and typespec is None:
956        # TODO(b/183245576): handle cases where the typespec is known
957        # but the dictionary is empty.
958        return StructuredTensor._from_pylist_of_empty_dict(pyval, rank)
959      return StructuredTensor.from_fields(
960          fields=fields, shape=shape, validate=False)
961    except Exception as exc:
962      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
963
964  @classmethod
965  def _from_pylist_of_empty_dict(cls, pyval, rank):
966    """Converts a pylist of empty dictionaries to StructuredTensors."""
967    if rank == 0:
968      return StructuredTensor.from_fields(fields={}, shape=(), validate=False)
969    elif rank == 1:
970      nrows = len(pyval)
971      shape = (nrows,)
972      return StructuredTensor.from_fields(fields={}, shape=shape, nrows=nrows)
973    elif rank > 1:
974      ragged_zeros = ragged_factory_ops.constant(_dicts_to_zeros(pyval))
975      nrows = len(pyval)
976      shape = tensor_shape.TensorShape([len(pyval)] + ([None] * (rank - 1)))
977      return StructuredTensor.from_fields(
978          fields={},
979          shape=shape,
980          row_partitions=ragged_zeros._nested_row_partitions,  # pylint:disable=protected-access
981          nrows=nrows)
982
983  @classmethod
984  def _from_pylist_of_value(cls, pyval, typespec, path_so_far):
985    """Converts python list `pyval` to a Tensor or RaggedTensor with rank>1."""
986    if typespec is None:
987      try:
988        return ragged_factory_ops.constant(pyval)
989      except Exception as exc:
990        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
991    elif isinstance(typespec, tensor_spec.TensorSpec):
992      try:
993        result = constant_op.constant(pyval, typespec.dtype)
994      except Exception as exc:
995        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
996      if not typespec.shape.is_compatible_with(result.shape):
997        raise ValueError('Value at %r does not match typespec: %r vs %r' %
998                         (path_so_far, typespec, pyval))
999      return result
1000    elif isinstance(typespec, ragged_tensor.RaggedTensorSpec):
1001      # pylint: disable=protected-access
1002      try:
1003        return ragged_factory_ops.constant(
1004            pyval,
1005            dtype=typespec._dtype,
1006            ragged_rank=typespec._ragged_rank,
1007            row_splits_dtype=typespec._row_splits_dtype,
1008            inner_shape=typespec._shape[typespec._ragged_rank + 1:])
1009      except Exception as exc:
1010        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
1011    elif isinstance(typespec, StructuredTensor.Spec):
1012      empty_rank = _pyval_empty_list_depth(pyval)
1013      if empty_rank is None:
1014        raise ValueError('Value at %r does not match typespec: %r vs %r' %
1015                         (path_so_far, typespec, pyval))
1016      else:
1017        return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec,
1018                                        path_so_far)
1019    else:
1020      raise ValueError('Value at %r does not match typespec: %r vs %r' %
1021                       (path_so_far, typespec, pyval))
1022
1023  @classmethod
1024  def _from_pyscalar(cls, pyval, typespec, path_so_far):
1025    """Converts python scalar value `pyval` to a Tensor."""
1026    if typespec is None:
1027      try:
1028        return constant_op.constant(pyval)
1029      except Exception as exc:
1030        raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
1031    else:
1032      if not (isinstance(typespec, tensor_spec.TensorSpec) and
1033              typespec.shape.rank == 0):
1034        raise ValueError('Value at %r does not match typespec: %r vs %r' %
1035                         (path_so_far, typespec, pyval))
1036      # TODO(edloper): Check that typespec.shape matches.
1037      return constant_op.constant(pyval, typespec.dtype)
1038
1039  #=============================================================================
1040  # Transforms
1041  #=============================================================================
1042
1043  # TODO(edloper): Add a 'validate' option here?
1044  # TODO(edloper): Unify nomenclature with RaggedTensor.  Should RaggedTensor
1045  # have a partition_outer_dimension method?
1046  def partition_outer_dimension(self, row_partition):
1047    """Partitions the outer dimension of this StructuredTensor.
1048
1049    Returns a new `StructuredTensor` with the same values as `self`, where
1050    the outer dimension is partitioned into two (possibly ragged) dimensions.
1051    Requires that this StructuredTensor have an outer dimension (i.e.,
1052    `self.shape.rank > 0`).
1053
1054    >>> st = StructuredTensor.from_pyval(
1055    ...     [{'foo': 12}, {'foo': 33}, {'foo': 99}])
1056    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
1057    >>> st.partition_outer_dimension(partition)
1058    <StructuredTensor(
1059      fields={
1060        "foo": <tf.RaggedTensor [[12, 33], [], [99]]>},
1061      shape=(3, None))>
1062
1063    Args:
1064      row_partition: A `RowPartition`.
1065
1066    Returns:
1067      A `StructuredTensor` with rank `values.rank + 1`.
1068    """
1069    if not isinstance(row_partition, RowPartition):
1070      raise TypeError('row_partition must be a RowPartition.')
1071    if self.shape.rank == 0:
1072      raise ValueError('Shape %s must have rank at least 1' % self.shape)
1073    return _partition_outer_dimension(self, row_partition)
1074
1075  def merge_dims(self, outer_axis, inner_axis):
1076    """Merges outer_axis...inner_axis into a single dimension.
1077
1078    Returns a copy of this RaggedTensor with the specified range of dimensions
1079    flattened into a single dimension, with elements in row-major order.
1080
1081    >>> st = StructuredTensor.from_pyval(
1082    ...     [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]])
1083    >>> st.merge_dims(0, 1)
1084    <StructuredTensor(
1085      fields={
1086        "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)},
1087      shape=(3,))>
1088
1089    Args:
1090      outer_axis: `int`: The first dimension in the range of dimensions to
1091        merge. May be negative (to index from the last dimension).
1092      inner_axis: `int`: The last dimension in the range of dimensions to merge.
1093        May be negative (to index from the last dimension).
1094
1095    Returns:
1096      A copy of this tensor, with the specified dimensions merged into a
1097      single dimension.  The shape of the returned tensor will be
1098      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
1099      is the total number of slices in the merged dimensions.
1100    """
1101    outer_axis = array_ops.get_positive_axis(
1102        outer_axis,
1103        self.shape.rank,
1104        axis_name='outer_axis',
1105        ndims_name='rank(self)')
1106    inner_axis = array_ops.get_positive_axis(
1107        inner_axis,
1108        self.shape.rank,
1109        axis_name='inner_axis',
1110        ndims_name='rank(self)')
1111    if not outer_axis <= inner_axis:
1112      raise ValueError('Expected outer_axis (%d) to be less than or equal to '
1113                       'inner_axis (%d)' % (outer_axis, inner_axis))
1114    return _merge_dims(self, outer_axis, inner_axis)
1115
1116  class Spec:
1117    """A spec for StructuredTensor."""
1118
1119    def __validate__(self):
1120      assert self._ragged_shape is not None
1121
1122    @classmethod
1123    def _from_fields_and_rank(cls, fields, rank):
1124      """Creates a spec of a StructuredTensor with fields and rank."""
1125      shape = None
1126      for (k, v) in fields.items():
1127        field_shape_untruncated = _dynamic_ragged_shape_spec_from_spec(v)
1128        if field_shape_untruncated is None:
1129          raise ValueError(f'Cannot convert spec of {k}.')
1130        untruncated_rank = field_shape_untruncated.rank
1131        if (untruncated_rank is not None
1132            and untruncated_rank < rank):
1133          raise ValueError(
1134              f'Rank of field {k} is {untruncated_rank}, '
1135              f'but must be at least {rank}.')
1136        field_shape = field_shape_untruncated._truncate(rank)  # pylint: disable=protected-access
1137        if shape is None:
1138          shape = field_shape
1139        else:
1140          shape = shape._merge_with(field_shape)
1141      return StructuredTensor.Spec(_ragged_shape=shape, _fields=fields)
1142
1143    @classmethod
1144    def _from_shape(
1145        cls, shape: dynamic_ragged_shape.DynamicRaggedShape
1146    ) -> 'StructuredTensor.Spec':
1147      """Creates the spec of an empty StructuredTensor."""
1148      return StructuredTensor.Spec(_ragged_shape=shape, _fields={})
1149
1150    # For backwards compatibility
1151    @property
1152    def _shape(self) -> tensor_shape.TensorShape:
1153      return self._ragged_shape._to_tensor_shape()  # pylint: disable=protected-access
1154
1155    # For backwards compatibility
1156    @property
1157    def _field_specs(self) -> Dict[str, type_spec.TypeSpec]:
1158      return self._fields
1159
1160    # For backwards compatibility
1161    @property
1162    def shape(self) -> tensor_shape.TensorShape:
1163      return self._shape
1164
1165    # For backwards compatibility
1166    @property
1167    def rank(self):
1168      return self._ragged_shape.rank
1169
1170
# Pattern a string must match to be a valid field name: an ASCII letter
# followed by any number of ASCII letters, digits, or underscores.
# Note: we plan to relax (or possibly eliminate) this restriction in the
# future; you should not rely on the fact that some field names are currently
# disallowed.
_FIELD_NAME_RE = re.compile(r'^[a-zA-Z][a-zA-Z0-9_]*$')
1175
1176#=============================================================================
# Helper functions
1178#=============================================================================
1179# TODO(edloper): Move some of these helpers to row_partition.py?
1180
1181
def _convert_to_structured_field_value(value):
  """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

  Tensors, RaggedTensors, StructuredTensors, and other `ExtensionType` values
  are returned unchanged.  Values for which `ragged_tensor.is_ragged` is true
  are converted with `convert_to_tensor_or_ragged_tensor`.  Anything else is
  converted with `ops.convert_to_tensor`.

  Args:
    value: The field value to convert.

  Returns:
    A `Tensor`, `RaggedTensor`, `StructuredTensor`, or other `ExtensionType`.

  Raises:
    TypeError: If `value` cannot be converted to a `Tensor`.
  """
  if isinstance(value,
                (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
    return value
  elif ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  elif isinstance(value, extension_type.ExtensionType):
    return value
  else:
    try:
      return ops.convert_to_tensor(value)
    except (ValueError, TypeError) as e:
      # Wrap `value` in a 1-tuple: if `value` is itself a tuple, `%` would
      # otherwise treat its elements as separate format arguments and fail
      # with an unrelated "not all arguments converted" TypeError.
      raise TypeError('Unexpected type for value in `fields`: %r' %
                      (value,)) from e
1197
1198
def _find_shape_dtype(fields, nrows, row_partitions):
  """Returns the single row-partition dtype shared by all shape sources.

  The dtype is collected from three sources: the row_splits of RaggedTensor
  fields and the `nrows()` of non-scalar StructuredTensor fields; the given
  `row_partitions`; and `nrows` (when it is a Tensor).  Every source that
  provides a dtype must agree.  If none provides one, `int64` is returned.

  Raises:
    ValueError: If any of the sources disagree about the dtype.
  """
  field_dtypes = {}
  for name, value in fields.items():
    if isinstance(value, ragged_tensor.RaggedTensor):
      field_dtypes[name] = value.row_splits.dtype
    elif isinstance(value, StructuredTensor) and value.rank > 0:
      field_dtypes[name] = value.nrows().dtype

  field_dtype = None
  for candidate in field_dtypes.values():
    if field_dtype is None:
      field_dtype = candidate
    elif field_dtype != candidate:
      raise ValueError('field values have incompatible row_partition dtypes. ' +
                       f'field_dtypes: {field_dtypes}')

  row_partition_dtype = None
  if row_partitions is not None:
    row_partition_dtypes = [partition.dtype for partition in row_partitions]
    for candidate in row_partition_dtypes:
      if row_partition_dtype is None:
        row_partition_dtype = candidate
      elif row_partition_dtype != candidate:
        raise ValueError('row_partitions have incompatible dtypes with '
                         f'themselves:{row_partition_dtypes}')

  nrows_dtype = nrows.dtype if isinstance(nrows, ops.Tensor) else None

  shape_dtypes = {
      dtype for dtype in (field_dtype, row_partition_dtype, nrows_dtype)
      if dtype is not None
  }
  if len(shape_dtypes) > 1:
    raise ValueError('row_partition dtypes are inconsistent: ' +
                     f'field_dtype:{field_dtype} ' +
                     f'row_partition_dtype:{row_partition_dtype} ' +
                     f'nrows_dtype:{nrows_dtype}')
  if shape_dtypes:
    return shape_dtypes.pop()
  return dtypes.int64
1243
1244
def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Combines a running row count `nrows` with the row count of `value`.

  * If `nrows` is None, `value`'s row count is adopted.
  * Otherwise, if both `static_nrows` and `value`'s static row count are
    known, they are checked for compatibility in Python, and `value`'s row
    count is adopted.
  * Otherwise, if `validate` is true, a runtime assertion that the two row
    counts are equal is attached to `nrows`.

  Args:
    nrows: scalar integer Tensor, or None.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for `nrows`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`, where `static_nrows` has been merged with
    `value`'s static row count.

  Raises:
    ValueError: If the static row counts are both known and incompatible.
  """
  static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  if isinstance(value, ops.Tensor):
    value_nrows = array_ops.shape(value, out_type=dtype)[0]
  else:
    value_nrows = value.nrows()

  if nrows is None:
    merged_nrows = value_nrows
  elif (static_value_nrows.value is not None and
        static_nrows.value is not None):
    if not static_value_nrows.is_compatible_with(static_nrows):
      raise ValueError('fields have incompatible nrows')
    # Both counts are statically known and compatible: no runtime check.
    merged_nrows = value_nrows
  elif validate:
    equal_check = check_ops.assert_equal(
        nrows, value_nrows, message='fields have incompatible nrows')
    merged_nrows = control_flow_ops.with_dependencies([equal_check], nrows)
  else:
    merged_nrows = nrows

  merged_static = static_nrows._merge_with(static_value_nrows)  # pylint: disable=protected-access
  return merged_nrows, merged_static
1280
1281
def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Merges `row_partitions` with the first `rank - 1` row partitions of `value`.

  Args:
    row_partitions: A sequence of `RowPartition`s, or None.
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    rank: The number of outer dimensions to consider.
    dtype: dtype used when deriving partitions from a `Tensor` shape.
    validate: Passed to `RowPartition._merge_precomputed_encodings`.

  Returns:
    A tuple of `rank - 1` `RowPartition`s.
  """
  if isinstance(value, ops.Tensor):
    value_partitions = _row_partitions_for_tensor(value, rank, dtype)
  elif isinstance(value, ragged_tensor.RaggedTensor):
    value_partitions = _row_partitions_for_ragged_tensor(value, rank, dtype)
  else:
    assert isinstance(value, StructuredTensor), type(value)
    value_partitions = value.row_partitions[:rank - 1]

  assert len(value_partitions) == rank - 1

  # Nothing to merge with yet: just adopt `value`'s partitions.
  if row_partitions is None:
    return tuple(value_partitions)

  merged = []
  for existing, incoming in zip(row_partitions, value_partitions):
    merged.append(existing._merge_precomputed_encodings(incoming, validate))  # pylint: disable=protected-access
  return tuple(merged)
1302
1303
def _row_partitions_for_tensor(value, rank, dtype):
  """Returns uniform row partitions for the outer `rank` dims of a tf.Tensor."""
  return _row_partitions_for_uniform_shape(
      array_ops.shape(value, out_type=dtype), rank)
1308
1309
def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Returns the first `rank - 1` row partitions of a tf.RaggedTensor.

  If the ragged tensor has fewer than `rank - 1` ragged partitions, the rest
  are derived (as uniform partitions) from the shape of its `flat_values`.
  """
  assert rank > 1
  num_needed = rank - 1
  partitions = value._nested_row_partitions[:num_needed]  # pylint: disable=protected-access
  num_missing = num_needed - len(partitions)
  if num_missing > 0:
    # `num_missing + 1` dims of flat_values give `num_missing` partitions.
    partitions += _row_partitions_for_tensor(
        value.flat_values, num_missing + 1, dtype)
  assert len(partitions) == num_needed
  return partitions
1319
1320
def _row_partitions_for_uniform_shape(shape, rank):
  """Returns row partitions for the given shape Tensor.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for

  Returns:
    A tuple of (rank-1) `RowPartition`s with uniform row length.
  """
  # cumulative_sizes[i] is the number of elements spanned by dims [0, i].
  cumulative_sizes = math_ops.cumprod(shape[:rank])
  partitions = []
  for axis in range(rank - 1):
    partitions.append(
        RowPartition.from_uniform_row_length(
            uniform_row_length=shape[axis + 1],
            nvals=cumulative_sizes[axis + 1],
            nrows=cumulative_sizes[axis]))
  return tuple(partitions)
1339
1340
def _pyval_field_major_to_node_major(keys, values, depth):
  """Regroup each field (k, v) from dict-of-list to list-of-dict.

  Converts a "field-major" encoding, where each key maps to one nested list
  holding that field's values for every struct, into the "node-major"
  encoding: a nested list of dicts.

  Args:
    keys: The field names (list of string).  Must not be empty.
    values: The field values (list of python values).  Must have the same length
      as `keys`.
    depth: The list depth at which dictionaries should be created.

  Returns:
    A nested list of dict, with depth `depth`.
  """
  assert keys
  if depth == 0:
    return dict(zip(keys, values))
  # Every field must have the same number of entries at this level.
  expected_len = len(values[0])
  assert all(len(other) == expected_len for other in values[1:])
  return [
      _pyval_field_major_to_node_major(keys, row, depth - 1)
      for row in zip(*values)
  ]
1366
1367
def _empty_dict_pylist_from_row_partitions(row_partitions, nrows):
  """Returns a python list of empty dicts from the given row partitions.

  Args:
    row_partitions: The row-partitions describing the ragged shape of the
      result.
    nrows: The number of rows in the outermost row-partition.  (Or if
      `len(row_partitions)==0`, then the number of empty dicts to return.)

  Returns:
    A nested python list whose leaves (if any) are empty python dicts.
  """
  if not row_partitions:
    return [{} for _ in range(nrows)]
  splits = row_partitions[0].row_splits()
  # Build the flattened inner level first, then cut it into rows using the
  # outermost partition's row_splits.
  inner = _empty_dict_pylist_from_row_partitions(row_partitions[1:],
                                                 splits[-1])
  num_rows = len(splits) - 1
  return [inner[splits[row]:splits[row + 1]] for row in range(num_rows)]
1387
1388
def _pyval_find_struct_keys_and_depth(pyval, keys):
  """Finds the keys & depth of nested dictionaries in `pyval`.

  Args:
    pyval: A nested structure of lists, tuples, and dictionaries.
    keys: (output parameter) A set, which will be updated with any keys that are
      found in the nested dictionaries.

  Returns:
    The nesting depth of dictionaries in `pyval`, or `None` if `pyval` does
    not contain any dictionaries.
  Raises:
    ValueError: If dictionaries have inconsistent depth.
  """
  if isinstance(pyval, dict):
    keys.update(pyval.keys())
    return 0
  if not isinstance(pyval, (list, tuple)):
    return None

  depth = None
  for child in pyval:
    child_depth = _pyval_find_struct_keys_and_depth(child, keys)
    if child_depth is None:
      # Children without any dicts do not constrain the depth.
      continue
    found = child_depth + 1
    if depth is None:
      depth = found
    elif depth != found:
      raise ValueError('Inconsistent depth of dictionaries')
  return depth
1418
1419
def _pyval_update_fields(pyval, fields, depth):
  """Append the field values from `pyval` to `fields`.

  Args:
    pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s)
      should be appended to `fields`.
    fields: A dictionary mapping string keys to field values.  Field values
      extracted from `pyval` are appended to this dictionary's values.
    depth: The depth at which `pyval` should be appended to the field values.

  Raises:
    ValueError: If `pyval` is neither a dict nor a list/tuple.
  """
  is_dict = isinstance(pyval, dict)
  is_sequence = isinstance(pyval, (list, tuple))
  if not (is_dict or is_sequence):
    raise ValueError('Expected dict or nested list/tuple of dict')

  for key, target in fields.items():
    # Descend to the innermost list currently being filled at this depth.
    for _ in range(depth - 1):
      target = target[-1]
    # A dict contributes its value; a list opens a new (empty) sub-list.
    target.append(pyval[key] if is_dict else [])

  if is_sequence:
    for child in pyval:
      _pyval_update_fields(child, fields, depth + 1)
1441
1442
def _pyval_empty_list_depth(pyval):
  """Find the max depth for nested empty lists.

  Args:
    pyval: A nested python list.

  Returns:
    The maximum depth of empty lists in `pyval`, or None if `pyval` contains
    anything other than nested empty lists.
  """
  if not isinstance(pyval, list):
    return None
  if not pyval:
    return 1
  child_depths = [_pyval_empty_list_depth(child) for child in pyval]
  if None in child_depths:
    return None
  return max(child_depths) + 1
1463
1464
def _replace_row_partitions(value, new_partitions):
  """Updates `value` to use `new_partitions` as its (outer) row partitions.

  `StructuredTensor.from_fields` merges the row partitions of all fields, and
  then uses this function to make every field share the same merged
  `RowPartition` objects for the dimensions they have in common.

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A list of row-partitions that should be used by `value`.
      Must be equivalent to `value`'s current row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have been
    replaced by `new_partitions`.
  """
  # Dense tensors carry no row partitions; an empty list replaces nothing.
  if isinstance(value, ops.Tensor) or not new_partitions:
    return value

  if isinstance(value, ragged_tensor.RaggedTensor):
    inner_values = _replace_row_partitions(value.values, new_partitions[1:])
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=inner_values, row_partition=new_partitions[0])

  assert isinstance(value, StructuredTensor)
  replaced_fields = {
      name: _replace_row_partitions(field, new_partitions)
      for name, field in value._fields.items()  # pylint: disable=protected-access
  }
  return StructuredTensor._old_init(  # pylint: disable=protected-access
      fields=replaced_fields,
      shape=value.shape,
      nrows=value.nrows(),
      row_partitions=(tuple(new_partitions) +
                      tuple(value.row_partitions[len(new_partitions):])))
1501
1502
def _partition_outer_dimension(value, row_partition):
  """Splits the outer dimension of `value` into rows using `row_partition`.

  Examples:

    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
    >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
    <tf.RaggedTensor [[1, 2], [], [3]]>

    >>> struct_value = StructuredTensor.from_pyval(
    ...     [{'x': 1}, {'x': 2}, {'x': 3}])
    >>> _partition_outer_dimension(struct_value, partition)
    <StructuredTensor(
      fields={
        "x": <tf.RaggedTensor [[1, 2], [], [3]]>},
      shape=(3, None))>

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value with the same type as `value`, where
    `result.rank = value.rank + 1`.  A dense `Tensor` is reshaped when
    `row_partition` has a uniform row length; otherwise it becomes a
    `RaggedTensor`.
  """
  uniform_length = row_partition.uniform_row_length()

  if isinstance(value, ops.Tensor) and uniform_length is not None:
    # Uniform rows: a plain reshape to [nrows, uniform_length, ...] suffices.
    outer_dims = [row_partition.nrows(), uniform_length]
    inner_dims = array_ops.shape(value, out_type=row_partition.dtype)[1:]
    return array_ops.reshape(
        value, array_ops.concat([outer_dims, inner_dims], axis=0))

  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)

  assert isinstance(value, StructuredTensor)
  static_shape = tensor_shape.TensorShape(
      [row_partition.static_nrows,
       row_partition.static_uniform_row_length]).concatenate(value.shape[1:])
  partitioned_fields = {
      field_name: _partition_outer_dimension(field_value, row_partition)
      for field_name, field_value in value._fields.items()
  }
  return StructuredTensor._old_init(  # pylint: disable=protected-access
      partitioned_fields,
      static_shape,
      row_partition.nrows(),
      (row_partition,) + value.row_partitions)
1551
1552
def _merge_dims(value, outer_axis, inner_axis):
  """Collapses dimensions `outer_axis...inner_axis` of `value` into one.

  A `StructuredTensor` is handled by merging the same dimensions in every
  field and in its ragged shape; anything else goes to
  `ragged_tensor.merge_dims`.
  """
  assert outer_axis < inner_axis
  if not isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    assert isinstance(value, StructuredTensor)
    merged_fields = {
        field_name: _merge_dims(field_value, outer_axis, inner_axis)
        for field_name, field_value in value._fields.items()
    }
    merged_shape = value._ragged_shape._merge_dims(  # pylint: disable=protected-access
        outer_axis, inner_axis)
    return StructuredTensor(merged_fields, merged_shape)
  return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
1565
1566
# Private sentinel: a fresh `object()` instance that compares equal only to
# itself, so no value outside this module can coincide with it.
_structured_tensor_factory_key: object = object()
1568
1569
def _dynamic_ragged_shape_spec_from_spec(
    spec: Union[dynamic_ragged_shape.DynamicRaggedShape.Spec,
                ragged_tensor.RaggedTensorSpec, StructuredTensor.Spec,
                tensor_spec.TensorSpec]
) -> dynamic_ragged_shape.DynamicRaggedShape.Spec:
  """Returns the `DynamicRaggedShape.Spec` that describes the shape of `spec`.

  A `StructuredTensor.Spec` already carries its ragged shape spec, which is
  returned directly; every other spec is converted with
  `DynamicRaggedShape.Spec._from_spec`.
  """
  if not isinstance(spec, StructuredTensor.Spec):
    return dynamic_ragged_shape.DynamicRaggedShape.Spec._from_spec(spec)  # pylint: disable=protected-access
  return spec._ragged_shape  # pylint: disable=protected-access
1579
1580
def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]:
  """Converts a field name given as a `str`, `list`, or `tuple` to a tuple.

  A single string becomes a one-element tuple, a list is copied into a tuple,
  and a tuple is returned as-is (the same object).  Any other type fails an
  `assert`.
  """
  if isinstance(name, str):
    normalized = (name,)
  elif isinstance(name, list):
    normalized = tuple(name)
  else:
    assert isinstance(name, tuple)
    normalized = name
  return normalized
1589
1590
def _dicts_to_zeros(pyval):
  """Replaces every dictionary in a nested python list with the integer `0`.

  Args:
    pyval: A `dict`, or an iterable (e.g. a nested list) whose leaves are
      dicts.

  Returns:
    `0` if `pyval` is a dict.  Otherwise a new nested `list` with the same
    structure as `pyval`, where each dict is replaced by `0`.  Tuples are
    turned into lists as well.
  """
  if not isinstance(pyval, dict):
    return list(map(_dicts_to_zeros, pyval))
  return 0
1596
1597
def _merge_dims_generic(source, outer, inner):
  """Merges dimensions `outer...inner` of `source` into a single dimension.

  If `outer == inner`, this is a no-op.  If `inner < outer`, this fails.
  If `inner >= source.shape.rank`, the behavior is undefined.

  Args:
    source: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    outer: A nonnegative python int: the first dimension to merge.
    inner: A nonnegative python int: the last dimension to merge (the range
      `outer...inner` is inclusive).

  Returns:
    `source` with dimensions `outer...inner` merged into a single dimension.
  """
  if not isinstance(source, StructuredTensor):
    return ragged_tensor.merge_dims(source, outer, inner)
  return source.merge_dims(outer, inner)
1619
1620
def _dynamic_ragged_shape_from_tensor(
    field, dtype=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the `DynamicRaggedShape` of `field`.

  This extends `DynamicRaggedShape.from_tensor` to `StructuredTensor`, whose
  own `_ragged_shape` is returned unchanged.

  Args:
    field: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    dtype: Forwarded as `out_type` to `array_ops.shape_v2`.  It is not used
      when `field` is a `StructuredTensor`.

  Raises:
    TypeError: if `array_ops.shape_v2(field)` returns something other than a
      `Tensor` or a `DynamicRaggedShape`.
  """
  if isinstance(field, StructuredTensor):
    return field._ragged_shape  # pylint: disable=protected-access

  field_shape = array_ops.shape_v2(field, out_type=dtype)
  if isinstance(field_shape, ops.Tensor):
    # A dense shape vector: wrap it as a shape with no ragged partitions.
    return dynamic_ragged_shape.DynamicRaggedShape(
        row_partitions=[], inner_shape=field_shape)
  if isinstance(field_shape, dynamic_ragged_shape.DynamicRaggedShape):
    return field_shape

  # TODO(martinz): add a test for the following line.
  raise TypeError(f'Expected shape tf.shape({field}) to return a Tensor or a '
                  f'DynamicRaggedShape. Instead, got: {field_shape}.')
1637
1638
def _merge_with_optional(
    a: Optional[dynamic_ragged_shape.DynamicRaggedShape],
    b: Optional[dynamic_ragged_shape.DynamicRaggedShape]
    ) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Merges two optional shapes, treating `None` as "no information".

  Returns:
    `a._merge_with(b)` when both are given.  Otherwise whichever argument is
    not `None`, or `None` when both are `None`.
  """
  if a is not None and b is not None:
    return a._merge_with(b)  # pylint: disable=protected-access
  return b if a is None else a
1648
1649
def _shape_from_fields(
    fields, rank: int,
    dtype: dtypes.DType) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Merges the outer `rank` dimensions of every field's shape into one shape.

  Args:
    fields: A mapping from field name to field value (`Tensor`, `RaggedTensor`,
      or `StructuredTensor`).
    rank: How many outer dimensions of each field's shape to keep.
    dtype: Forwarded to `_dynamic_ragged_shape_from_tensor` for each field.

  Returns:
    The merged `DynamicRaggedShape`, or `None` if `fields` is empty.

  Raises:
    ValueError: if computing, truncating, or merging some field's shape fails.
      The underlying exception is chained as the cause.
  """
  merged_shape = None
  for field_name, field_value in fields.items():
    try:
      truncated_shape = _dynamic_ragged_shape_from_tensor(
          field_value, dtype=dtype)[:rank]
      merged_shape = _merge_with_optional(merged_shape, truncated_shape)
    except Exception as err:
      # Deliberately broad: any failure is re-raised tagged with the field.
      raise ValueError(f'Error in shape of {field_name}') from err
  return merged_shape
1666
1667
1668# pylint:disable=protected-access
def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions):
  """Produce a DynamicRaggedShape for StructuredTensor.

  The result combines every available source of shape information:
    * the static `shape`, when it is fully defined;
    * the outer `shape.rank` dimensions of each field's shape;
    * `nrows` (only when `shape.rank == 1`);
    * `row_partitions` (only when `shape.rank > 1`).
  A rank-0 `shape` short-circuits to an empty (scalar) shape.

  Args:
    fields: A `dict` mapping field names to field values.
    shape: A `TensorShape`; its rank must be known.
    nrows: `None`, an `int`, or a `Tensor`.  It is passed to
      `_find_shape_dtype`, and is used as the row count when `shape.rank == 1`
      (a known static `shape[0]` takes precedence).
    row_partitions: `None` or a tuple of row partitions.  It is passed to
      `_find_shape_dtype`, and is merged into the result when
      `shape.rank > 1`.

  Returns:
    A `DynamicRaggedShape`.

  Raises:
    TypeError: if `shape` has unknown rank.
    ValueError: if, for rank >= 1, no source of shape information is
      available.
  """
  assert isinstance(fields, dict), fields
  assert isinstance(shape, tensor_shape.TensorShape), shape
  assert nrows is None or isinstance(nrows, ops.Tensor) or isinstance(
      nrows, int), nrows
  assert row_partitions is None or isinstance(row_partitions,
                                              tuple), row_partitions
  rank = shape.rank

  if rank is None:
    raise TypeError("StructuredTensor's shape must have known rank.")

  # TODO(martinz): figure out whether to validate.
  dtype = _find_shape_dtype(fields, nrows, row_partitions)

  if rank == 0:
    # A scalar StructuredTensor has an empty shape; fields, nrows and
    # row_partitions are not consulted.  Returning before building the
    # fully-defined static shape below avoids constructing a result that
    # would only be discarded.
    return dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
        array_ops.zeros((0,), dtype=dtype))

  result = None
  if shape.is_fully_defined():
    result = dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
        shape.as_list(), dtype=dtype)

  result = _merge_with_optional(result, _shape_from_fields(fields, rank, dtype))

  if rank == 1:
    # A statically known number of rows overrides the `nrows` argument.
    alt_value = tensor_shape.dimension_value(shape[0])
    if alt_value is not None:
      nrows = alt_value
    if nrows is not None:
      result = _merge_with_optional(
          result,
          dynamic_ragged_shape.DynamicRaggedShape._from_inner_shape(
              [nrows], dtype=dtype))
    if result is None:
      raise ValueError('Must specify `nrows`, a fully specified `shape`,'
                       ' or have `fields` if `rank=1`')
    return result

  if row_partitions:
    result = _merge_with_optional(
        result, dynamic_ragged_shape.DynamicRaggedShape.from_row_partitions(
            row_partitions, dtype=dtype))

  if result is None:
    raise ValueError('Must specify row_partitions, a fully specified shape, '
                     'or have fields if rank > 1')
  return result
1718
1719
1720# TODO(martinz): Drop this method or rename.
def StructuredTensorSpec(shape, field_specs):  # pylint:disable=invalid-name
  """A placeholder for the old StructuredTensorSpec.

  Builds a `StructuredTensor.Spec` from a static shape and per-field specs.

  Args:
    shape: Anything accepted by `tensor_shape.as_shape`.  Its rank must be
      known.
    field_specs: A `dict` mapping `str` field names to `TypeSpec`s.

  Returns:
    A `StructuredTensor.Spec` whose ragged shape is `shape`, merged with the
    outer `rank` dimensions of every field spec's shape.

  Raises:
    TypeError: if `field_specs` is not a dict with `str` keys and `TypeSpec`
      values, or if `shape` has unknown rank.
    ValueError: if a field spec cannot be converted to a shape spec, or if its
      known rank is smaller than the rank of `shape`.
  """
  if not isinstance(field_specs, dict):
    raise TypeError('field_specs must be a dictionary.')
  if not all(isinstance(field_name, str) for field_name in field_specs):
    raise TypeError('field_specs must be a dictionary with string keys.')
  if not all(isinstance(field_spec, type_spec.TypeSpec)
             for field_spec in field_specs.values()):
    raise TypeError('field_specs must be a dictionary with TypeSpec values.')

  merged_shape = (
      dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
          tensor_shape.as_shape(shape), 0, dtypes.int32))
  rank = merged_shape.rank
  if rank is None:
    raise TypeError("StructuredTensor's shape must have known rank.")

  for field_name, field_spec in field_specs.items():
    full_field_shape = _dynamic_ragged_shape_spec_from_spec(field_spec)
    if full_field_shape is None:
      raise ValueError(f'Cannot convert spec of {field_name}.')
    field_rank = full_field_shape.rank
    if field_rank is not None and field_rank < rank:
      raise ValueError(
          f'Rank of field {field_name} is {field_rank},'
          f' but must be at least {rank}.')
    # Only the outer `rank` dimensions are shared with the StructuredTensor.
    merged_shape = merged_shape._merge_with(full_field_shape._truncate(rank))
  return StructuredTensor.Spec(_ragged_shape=merged_shape, _fields=field_specs)
1752