xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/structured/structured_array_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""StructuredTensor array ops."""
16
17from typing import Sequence
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import random_ops
25from tensorflow.python.ops.ragged import dynamic_ragged_shape
26from tensorflow.python.ops.ragged import ragged_tensor
27from tensorflow.python.ops.ragged.row_partition import RowPartition
28from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
29from tensorflow.python.util import deprecation
30from tensorflow.python.util import dispatch
31
32
@dispatch.dispatch_for_api(array_ops.shape_v2)
def shape_v2(input: StructuredTensor, out_type=dtypes.int32,  # pylint: disable=redefined-builtin
             name=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Computes the shape of a StructuredTensor as a DynamicRaggedShape.

  Args:
    input: the StructuredTensor whose shape is requested.
    out_type: the dtype handed to `with_dtype` on the stored ragged shape.
    name: accepted for signature compatibility with the dispatched API; it is
      discarded.

  Returns:
    The ragged shape stored on `input`, converted with `with_dtype(out_type)`.
  """
  del name  # Unused.
  ragged_shape = input._ragged_shape  # pylint: disable=protected-access
  return ragged_shape.with_dtype(out_type)
39
40
@dispatch.dispatch_for_api(array_ops.shape)
def shape_v1(input: StructuredTensor, name=None,  # pylint: disable=redefined-builtin
             out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Computes the shape of a StructuredTensor as a DynamicRaggedShape.

  This variant takes `name` before `out_type`.

  Args:
    input: the StructuredTensor whose shape is requested.
    name: accepted for signature compatibility with the dispatched API; it is
      discarded.
    out_type: the dtype handed to `with_dtype` on the stored ragged shape.

  Returns:
    The ragged shape stored on `input`, converted with `with_dtype(out_type)`.
  """
  del name  # Unused.
  ragged_shape = input._ragged_shape  # pylint: disable=protected-access
  return ragged_shape.with_dtype(out_type)
47
48
@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
def expand_dims(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  """Inserts a new length-1 dimension into a StructuredTensor.

  StructuredTensor implementation of tf.expand_dims for the variant that also
  accepts the deprecated `dim` argument. The insertion index `axis` may be at
  most the rank of `input`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the StructuredTensor to expand.
    axis: index at which the new dimension is inserted; it must satisfy
      `-(rank + 1) <= axis <= rank`.
    name: optional name for the op.
    dim: deprecated alias of `axis`.

  Returns:
    A StructuredTensor whose rank is one greater than the rank of `input`.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # Reconcile the current argument with its deprecated spelling.
  resolved_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'dim', dim)
  return _expand_dims_impl(input, resolved_axis, name=name)
81
82
@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
def expand_dims_v2(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a new length-1 dimension into a StructuredTensor.

  StructuredTensor implementation of tf.expand_dims. The insertion index
  `axis` may be at most the rank of `input`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the StructuredTensor to expand.
    axis: index at which the new dimension is inserted; it must satisfy
      `-(rank + 1) <= axis <= rank`.
    name: optional name for the op.

  Returns:
    A StructuredTensor whose rank is one greater than the rank of `input`.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  expanded = _expand_dims_impl(input, axis, name=name)
  return expanded
112
113
@dispatch.dispatch_for_types(array_ops.gather, StructuredTensor)
def gather(params,
           indices,
           validate_indices=None,
           name=None,
           axis=None,
           batch_dims=0):
  """Implements tf.gather for StructuredTensors.

  The gather is applied to every leaf field of `params` with the same
  `indices`, `axis` and `batch_dims`.

  Checks on illegal axis values, et cetera, are not (yet) supported.

  Args:
    params: the StructuredTensor to gather from.
    indices: a ragged or dense tensor (or a value convertible to one) holding
      the indices to gather.
    validate_indices: forwarded unchanged to `array_ops.gather`.
    name: the name of the enclosing name scope; 'gather' when None.
    axis: the axis of `params` to gather along; `batch_dims` when None.
    batch_dims: the number of batch dimensions.

  Returns:
    The contents of `params` reorganized according to `indices`.
  """
  scope_name = 'gather' if name is None else name
  with ops.name_scope(scope_name):
    requested_axis = batch_dims if axis is None else axis
    positive_axis = array_ops.get_positive_axis(
        requested_axis, params.shape.rank, ndims_name='params.shape.rank')
    gather_indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')

    def gather_leaf(leaf):
      # Same gather for every leaf; only the gathered value differs.
      return array_ops.gather(
          leaf,
          gather_indices,
          validate_indices=validate_indices,
          axis=positive_axis,
          batch_dims=batch_dims,
          name=None)

    return _extend_op_single(params, gather_leaf)
157
158
@dispatch.dispatch_for_types(array_ops.concat, StructuredTensor)
def concat(values, axis, name: str = 'concat'):
  """tf.concat for structured tensors.

  The inputs are first checked for static compatibility, then every leaf field
  is concatenated along `axis` with `array_ops.concat`.

  Does not support (yet) checks on illegal axis values, et cetera.

  Args:
    values: a sequence of StructuredTensors.
    axis: an axis to concatenate upon. It is normalized against the rank of
      `values[0]`; a tensor-valued axis is not handled yet.
    name: the name of the op(s); None falls back to 'concat'.

  Returns:
    A StructuredTensor holding the concatenation of `values` along `axis`.

  Raises:
    ValueError: if `values` fails the static compatibility check (not a
      sequence, empty, a non-StructuredTensor element, mismatched paths or
      mismatched submessage ranks).
  """
  if name is None:
    name = 'concat'
  _assert_concat_compatible_structured_tensors(values)
  # TODO(martinz): handle axis when it is a tensor.
  # Normalize before defining the leaf op so the closure only ever sees the
  # non-negative axis.
  axis = array_ops.get_positive_axis(axis, values[0].rank)

  def leaf_op(leaf_values):
    return array_ops.concat(leaf_values, axis)

  with ops.name_scope(name, 'StructuredConcat', values):
    return _extend_op(values, leaf_op)
182
183
@dispatch.dispatch_for_types(random_ops.random_shuffle, StructuredTensor)
def random_shuffle(value, seed=None, name=None):
  """Randomly reorders the rows (axis 0) of a StructuredTensor.

  A shuffled permutation of the row indices is drawn and the rows are then
  gathered in that order.

  Args:
    value: a structured tensor of rank at least one.
    seed: the seed forwarded to the index shuffle.
    name: the name for the enclosing name scope.

  Returns:
    The shuffled structured tensor.

  Raises:
    ValueError: if `value` has rank 0.
  """
  with ops.name_scope(name, 'shuffle', [value, seed]):
    if value.rank == 0:
      raise ValueError('Cannot shuffle a scalar StructuredTensor')
    row_ids = math_ops.range(value.nrows())
    shuffled_ids = random_ops.random_shuffle(row_ids, seed=seed)
    return gather(value, shuffled_ids, axis=0)
203
204
@dispatch.dispatch_for_types(array_ops.size_v2, StructuredTensor)
def size_v2(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the size of a StructuredTensor.

  Delegates to `size`, which takes the same arguments in a different order.

  Args:
    input: the StructuredTensor to measure.
    out_type: the dtype of the returned size.
    name: a name for the op.

  Returns:
    Whatever `size` returns for the same arguments.
  """
  return size(input, out_type=out_type, name=name)
210
211
212# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.size, StructuredTensor)
def size(input, name=None, out_type=dtypes.int32):  # pylint: disable=redefined-builtin
  """Returns the size of a StructuredTensor.

  Args:
    input: the StructuredTensor to measure.
    name: a name for the op.
    out_type: the dtype the result is cast to.

  Returns:
    With row partitions present, `nvals()` of the last partition, returned
    uncast when it or `out_type` is None and cast to `out_type` otherwise.
    Without row partitions, `nrows()` cast to `out_type`, or 1 cast to
    `out_type` when `nrows()` is None.
  """
  with ops.name_scope(name, 'size', [input]):
    partitions = input.row_partitions
    if partitions:
      # 2D and up: the innermost partition knows the total element count.
      element_count = partitions[-1].nvals()
      if element_count is None or out_type is None:
        return element_count
      return math_ops.cast(element_count, dtype=out_type)
    if input.nrows() is None:
      return math_ops.cast(1, out_type)  # scalar.
    return math_ops.cast(input.nrows(), out_type)  # vector.
228
229
230# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like, StructuredTensor)
def zeros_like(tensor, dtype=None, name=None, optimize=True):
  """zeros_like for StructuredTensor, in the form that accepts `optimize`.

  Args:
    tensor: a structured tensor.
    dtype: the dtype of the resulting zeros.
    name: a name for the op.
    optimize: accepted for signature compatibility; it is discarded.

  Returns:
    The result of `zeros_like_v2` for the same tensor, dtype and name.
  """
  del optimize  # Unused.
  zeros = zeros_like_v2(tensor, dtype=dtype, name=name)
  return zeros
236
237
238# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like_v2, StructuredTensor)
def zeros_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a zero.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.zeros_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 0.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.zeros_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[0], [0, 0]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting zeros. (default is tf.float32)
    name: a name for the op.
  Returns:
    a tensor of zeros of the same shape: a scalar or vector Tensor when
    `input` has no row partitions, otherwise a RaggedTensor built from the
    row partitions of `input`.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'zeros_like', [input]):
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.zeros([input.nrows()], dtype)  # vector.
      else:
        return array_ops.zeros([], dtype)  # scalar.
    # 2D and up: one zero per value of the innermost partition, wrapped in
    # the same nesting as the input.
    last_row_partition = input.row_partitions[-1]

    result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.zeros(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)
    return result
273
274
275# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like, StructuredTensor)
def ones_like(tensor, dtype=None, name=None, optimize=True):
  """Implementation of ones_like for StructuredTensor for TF v1.

  Args:
    tensor: a structured tensor.
    dtype: the dtype of the resulting ones.
    name: a name for the op.
    optimize: accepted for signature compatibility; it is discarded.

  Returns:
    The result of `ones_like_v2` for the same tensor, dtype and name.
  """
  del optimize
  return ones_like_v2(tensor, dtype=dtype, name=name)
281
282
283# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like_v2, StructuredTensor)
def ones_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a one.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.ones_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.ones_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[1], [1, 1]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting ones. (default is tf.float32)
    name: a name for the op.
  Returns:
    a tensor of ones of the same shape: a scalar or vector Tensor when
    `input` has no row partitions, otherwise a RaggedTensor built from the
    row partitions of `input`.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'ones_like', [input]):
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.ones([input.nrows()], dtype)  # vector.
      else:
        return array_ops.ones([], dtype)  # scalar.
    # 2D and up: one value per entry of the innermost partition, wrapped in
    # the same nesting as the input.
    last_row_partition = input.row_partitions[-1]

    result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.ones(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)
    return result
318
319
@dispatch.dispatch_for_types(array_ops.rank, StructuredTensor)
def rank(input, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a StructuredTensor as an int32 constant.

  Args:
    input: the StructuredTensor whose rank is requested.
    name: a name for the enclosing name scope.

  Returns:
    A constant of dtype int32 built from `input.rank`.
  """
  with ops.name_scope(name, 'rank', [input]):
    static_rank = input.rank
    return constant_op.constant(static_rank, dtype=dtypes.int32)
326
327
def _expand_dims_impl(st, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a new length-1 dimension into a StructuredTensor.

  Shared implementation behind the tf.expand_dims overrides for
  StructuredTensor. The insertion index `axis` may be at most the rank of
  `st`.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    st: the StructuredTensor to expand.
    axis: index at which the new dimension is inserted; it must satisfy
      `-(rank + 1) <= axis <= rank`.
    name: optional name for the op.

  Returns:
    A StructuredTensor whose rank is one greater than the rank of `st`.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # The result has rank + 1 dimensions, so normalize against that.
  axis = array_ops.get_positive_axis(
      axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
  with ops.name_scope(name, 'ExpandDims', [st, axis]):
    expanded_fields = {}
    for field_name, field_value in st._fields.items():
      expanded_fields[field_name] = array_ops.expand_dims(field_value, axis)

    leading_dims = st.shape[:axis]
    trailing_dims = st.shape[axis:]
    expanded_shape = leading_dims + (1,) + trailing_dims

    expanded_partitions = _expand_st_row_partitions(st, axis)

    # A new outermost dimension always has exactly one row.
    if axis > 0:
      expanded_nrows = st.nrows()
    else:
      expanded_nrows = 1

    return StructuredTensor.from_fields(
        expanded_fields,
        shape=expanded_shape,
        row_partitions=expanded_partitions,
        nrows=expanded_nrows)
369
370
def _expand_st_row_partitions(st, axis):
  """Builds the row partitions for `st` after inserting a length-1 axis.

  Args:
    st: the StructuredTensor being expanded.
    axis: the index of the inserted dimension.

  Returns:
    A tuple of RowPartitions for the expanded tensor. It is empty when a
    rank-0 tensor is expanded at axis 0.
  """

  def unit_rows(row_count):
    # A partition in which each of `row_count` rows holds exactly one value.
    return RowPartition.from_uniform_row_length(
        1, row_count, nrows=row_count, validate=False)

  if axis == 0:
    if st.shape.rank == 0:
      return ()
    # New outermost dimension: a single row containing every original row.
    total_rows = st.nrows()
    outer_partition = RowPartition.from_uniform_row_length(
        total_rows, total_rows, nrows=1, validate=False)
    return (outer_partition,) + st.row_partitions

  if axis == st.rank:
    # New innermost dimension: wrap each innermost value in its own row.
    if axis - 2 >= 0:
      row_count = st.row_partitions[axis - 2].nvals()
    else:
      row_count = st.nrows()
    return st.row_partitions + (unit_rows(row_count),)

  # New dimension in the middle: splice a unit partition in front of the
  # partition at `axis - 1`.
  split = axis - 1
  if split >= 0:
    row_count = st.row_partitions[split].nrows()
  else:
    row_count = st.nrows()
  inserted = unit_rows(row_count)
  return st.row_partitions[:split] + (inserted,) + st.row_partitions[split:]
390
391
392# TODO(martinz): consider allowing values to be nested.
def _extend_op(values, leaf_op, empty_st_op=None):
  """Lifts an op on Tensors/RaggedTensors to a list of StructuredTensors.

  The structure of `values[0]` is walked recursively. `leaf_op` is applied to
  the list of corresponding leaves, and `empty_st_op` is applied at every
  StructuredTensor node to obtain the field-less result for that node.

  Args:
    values: a list of structured tensors, ragged tensors, or tensors. All must
      have the same type. If they are structured tensors, they must have the
      same paths.
    leaf_op: an op taking a list of non-structured tensors.
    empty_st_op: an op taking a list of structured tensors and producing a
      structured tensor without fields. When None, one is derived from
      `leaf_op` with `empty_st_op_like_zeros`.

  Returns:
    the result of the extended op (a StructuredTensor, RaggedTensor, or Tensor)

  Raises:
    ValueError:
      If values is not a Sequence or is empty.
  """
  if not isinstance(values, Sequence):
    raise ValueError('Expected a list')
  if not values:
    raise ValueError('List cannot be empty')

  if empty_st_op is None:
    empty_st_op = empty_st_op_like_zeros(leaf_op)

  # The first element supplies the structure; the rest are assumed to match.
  template = values[0]
  if not isinstance(template, StructuredTensor):
    return leaf_op(values)

  # TODO(martinz): Calling empty_st_op may add unnecessary ops. Revisit later.
  skeleton = empty_st_op(values)
  field_names = template.field_names()
  if not field_names:
    return skeleton

  merged_fields = {
      field: _extend_op([v.field_value(field) for v in values], leaf_op,
                        empty_st_op)
      for field in field_names
  }
  return StructuredTensor.from_fields(merged_fields, shape=skeleton.shape)
438
439
def _extend_op_single(value, leaf_op, empty_st_op=None):
  """Applies `_extend_op` to one value instead of a list of values.

  Args:
    value: the single value to process.
    leaf_op: an op taking one non-structured tensor.
    empty_st_op: an optional op taking one structured tensor; None is passed
      through to `_extend_op` unchanged.

  Returns:
    The result of `_extend_op` on the one-element list `[value]`.
  """

  def lift(single_op):
    # Adapt an op on one element into an op on a one-element list.
    if single_op is None:
      return None

    def apply_to_only_element(singleton):
      (only,) = singleton
      return single_op(only)

    return apply_to_only_element

  lifted_leaf_op = lift(leaf_op)
  lifted_empty_st_op = lift(empty_st_op)
  return _extend_op([value], lifted_leaf_op, lifted_empty_st_op)
454
455
def empty_st_op_like_zeros(leaf_op):
  """Derives an op on field-less StructuredTensors from `leaf_op`.

  Args:
    leaf_op: an op taking a list of non-structured tensors.

  Returns:
    A function taking a list of values. It replaces each value with
    `zeros_like_v2(value, dtype=int32)`, applies `leaf_op` to those zeros, and
    returns a StructuredTensor shaped like the outcome.
  """

  def empty_st_op(values):
    zero_values = []
    for value in values:
      zero_values.append(zeros_like_v2(value, dtype=dtypes.int32))
    shaped_like = leaf_op(zero_values)
    return _structured_tensor_like(shaped_like)

  return empty_st_op
466
467
def _structured_tensor_from_dense_tensor(t):
  """Builds a field-less StructuredTensor shaped like the dense tensor `t`.

  Args:
    t: a dense tensor.

  Returns:
    A StructuredTensor without fields that has the shape of `t`.

  Raises:
    ValueError: if the shape of `t` is not fully defined and its rank is
      unknown.
  """
  static_shape = t.shape
  # Note: If a tensor will have rank 0,
  # it either has a fully defined shape or has unknown rank.
  if static_shape.is_fully_defined():
    return StructuredTensor.from_fields({}, shape=static_shape)
  if static_shape.rank is None:
    raise ValueError("Can't build StructuredTensor w/ unknown rank")
  if static_shape.rank == 1:
    # The row count is not known statically, so read it from the tensor.
    dynamic_nrows = array_ops.shape(t)[0]
    return StructuredTensor.from_fields(
        {}, shape=static_shape, nrows=dynamic_nrows)
  as_ragged = ragged_tensor.RaggedTensor.from_tensor(t)
  return _structured_tensor_from_row_partitions(
      static_shape, as_ragged._nested_row_partitions)
483
484
def _structured_tensor_from_row_partitions(shape, row_partitions):
  """Builds a field-less StructuredTensor from a shape and row partitions.

  Args:
    shape: the shape of the StructuredTensor to create.
    row_partitions: the row partitions of the StructuredTensor to create.

  Returns:
    A StructuredTensor without fields.
  """
  no_fields = {}
  return StructuredTensor.from_fields(
      no_fields, shape=shape, row_partitions=row_partitions)
489
490
# pylint: disable=protected-access
def _all_nested_row_partitions(rt):
  """Returns all nested row partitions in rt, including for dense dimensions.

  Args:
    rt: a Tensor or a RaggedTensor.

  Returns:
    A tuple of row partitions. For a Tensor it is empty when the rank is at
    most 1; otherwise it holds the nested row partitions of
    `RaggedTensor.from_tensor(rt)`. For a RaggedTensor it holds the tensor's
    own nested row partitions followed by those found recursively in its
    `flat_values`.
  """
  if not isinstance(rt, ops.Tensor):
    inner_partitions = _all_nested_row_partitions(rt.flat_values)
    outer_partitions = rt._nested_row_partitions  # pylint: disable=protected-access
    return outer_partitions + inner_partitions

  if rt.shape.rank <= 1:
    return ()
  as_ragged = ragged_tensor.RaggedTensor.from_tensor(rt)
  return as_ragged._nested_row_partitions  # pylint: disable=protected-access
504
505
def _structured_tensor_like(t):
  """Builds a field-less StructuredTensor shaped like a (composite) tensor.

  Args:
    t: a Tensor, a RaggedTensor, or a StructuredTensor.

  Returns:
    A StructuredTensor without fields that has the shape of `t`.
  """
  if isinstance(t, ops.Tensor):
    return _structured_tensor_from_dense_tensor(t)

  if ragged_tensor.is_ragged(t):
    ragged_partitions = _all_nested_row_partitions(t)
    return StructuredTensor.from_fields(
        {}, shape=t.get_shape(), row_partitions=ragged_partitions)

  # Anything else is treated as a StructuredTensor: copy its shape, row
  # partitions and row count, dropping the fields.
  return StructuredTensor.from_fields(
      {}, shape=t.shape, row_partitions=t.row_partitions, nrows=t.nrows())
518
519
def _get_all_paths(st):
  """Collects every field path reachable from a StructuredTensor.

  Args:
    st: the StructuredTensor to inspect.

  Returns:
    A set of tuples of field names. The empty tuple (the tensor itself) is
    always present; paths inside nested StructuredTensors are prefixed with
    the name of the field that holds them.
  """
  paths = {()}
  for field_name in st.field_names():
    child = st.field_value(field_name)
    if not isinstance(child, StructuredTensor):
      paths.add((field_name,))
      continue
    # The child's own empty path becomes (field_name,).
    paths.update(
        (field_name,) + sub_path for sub_path in _get_all_paths(child))
  return paths
531
532
def _get_all_ranks(st):
  """Collects the rank of a StructuredTensor and of all its submessages.

  Args:
    st: the StructuredTensor to inspect.

  Returns:
    A dict from path (a tuple of field names, empty for `st` itself) to the
    rank of the StructuredTensor at that path. Fields that are not
    StructuredTensors are omitted.
  """
  rank_by_path = {(): st.rank}
  for field_name in st.field_names():
    child = st.field_value(field_name)
    if isinstance(child, StructuredTensor):
      rank_by_path.update(
          ((field_name,) + sub_path, sub_rank)
          for sub_path, sub_rank in _get_all_ranks(child).items())
  return rank_by_path
543
544
def _assert_all_paths_match(values):
  """Checks that every StructuredTensor in `values` has the same paths.

  Args:
    values: a sequence of StructuredTensors.

  Raises:
    ValueError: if a path is present in some, but not all, of the tensors.
  """
  path_sets = [_get_all_paths(st) for st in values]
  mismatched = set()
  # Compare every later tensor against the first one.
  for candidate in path_sets[1:]:
    mismatched |= path_sets[0] ^ candidate
  if mismatched:
    raise ValueError(
        'Some paths are present in some, but not all, structured tensors: %r' %
        (mismatched,))
555
556
def _assert_all_ranks_match(values):
  """Checks that submessage ranks agree across all tensors in `values`.

  Args:
    values: a sequence of StructuredTensors.

  Raises:
    ValueError: if the path-to-rank mapping of any tensor differs from that of
      the first tensor.
  """
  rank_maps = [_get_all_ranks(st) for st in values]
  # TODO(martinz): If this becomes common, we can provide more detail.
  # e.g.: which path is inconsistent.
  if any(candidate != rank_maps[0] for candidate in rank_maps[1:]):
    raise ValueError('Ranks of sub-message do not match')
565
566
def _assert_concat_compatible_structured_tensors(values):
  """Statically rejects inputs on which concat would not make sense.

  `values` must be a non-empty sequence whose elements are all structured
  tensors with the same paths. Additionally, each path that is a submessage
  must have the same rank in every element.

  These constraints are sufficient for concat on the fields to be the same
  as concat on structured tensors. They catch, for example, paths that are
  missing from the first structured tensor but present in later ones, which
  the recursive algorithm would otherwise silently ignore, and submessages
  whose rank differs between structured tensors.

  All of these checks are static, as paths and submessage ranks are known.

  Args:
    values: a Sequence of StructuredTensors.

  Raises:
    ValueError: if there is any inconsistency as described above.
  """
  if not isinstance(values, Sequence):
    raise ValueError('values must be a list of StructuredTensors (not a list)')
  if not values:
    raise ValueError('values must not be an empty list')
  if not all(isinstance(st, StructuredTensor) for st in values):
    raise ValueError('values must be a list of StructuredTensors')
  _assert_all_paths_match(values)
  _assert_all_ranks_match(values)
600