xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/sparse_tensor.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Sparse tensors."""
16# pylint: disable=g-bad-name
17import collections
18
19import numpy as np
20
21from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
22from tensorflow.python import tf2
23from tensorflow.python.framework import composite_tensor
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.framework import type_spec
31from tensorflow.python.ops import gen_sparse_ops
32from tensorflow.python.types import internal
33from tensorflow.python.util import _pywrap_utils
34from tensorflow.python.util.tf_export import tf_export
35
# Module-level aliases for two private helpers of `ops`. They are used below by
# `SparseTensor.eval` and `SparseTensor._override_operator` respectively; the
# pylint pragma is needed because both names are protected members of `ops`.
# pylint: disable=protected-access
_eval_using_default_session = ops._eval_using_default_session
_override_helper = ops._override_helper
# pylint: enable=protected-access
40
41
@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
  """Represents a sparse tensor.

  A sparse tensor is represented in TensorFlow by three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  The Python `SparseTensor` class
  simply bundles those three tensors for convenience.  If you hold the three
  tensors separately, wrap them in a `SparseTensor` before handing them to the
  sparse ops.

  For `SparseTensor(indices, values, dense_shape)`, let `N` be the number of
  values and `ndims` the number of dimensions.  The components are:

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]` listing the
    (zero-based) positions of the elements that hold nonzero values. For
    example, `indices=[[1,3], [2,4]]` says that the elements at [1,3] and
    [2,4] have nonzero values.

  * `values`: A 1-D tensor of any type and shape `[N]` that gives the value of
    each element named in `indices`. For example, with `indices=[[1,3],
    [2,4]]`, passing `values=[18, 3.6]` means element [1,3] has the value 18
    and element [2,4] has the value 3.6.

  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]` giving the number of
    elements along each dimension of the dense tensor. For example,
    `dense_shape=[3,6]` describes a two-dimensional 3x6 tensor,
    `dense_shape=[2,3,4]` a three-dimensional 2x3x4 tensor, and
    `dense_shape=[9]` a one-dimensional tensor with 9 elements.

  The corresponding dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (equivalently,
  lexicographic order on the tuples `indices[i]`). The constructor does not
  enforce this, but most ops assume it. If a sparse tensor `st` is ordered
  incorrectly, `tf.sparse.reorder(st)` returns a fixed version.

  Example: The sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0]
   [0, 0, 2, 0]
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a `SparseTensor` from a sparse value.

    Args:
      sparse_tensor_value: A `SparseTensor` or `SparseTensorValue` (anything
        accepted by `is_sparse`).

    Returns:
      A new `SparseTensor` built from the `indices`, `values` and
      `dense_shape` of `sparse_tensor_value`.

    Raises:
      TypeError: If `sparse_tensor_value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if is_sparse(sparse_tensor_value):
      return SparseTensor(
          indices=sparse_tensor_value.indices,
          values=sparse_tensor_value.values,
          dense_shape=sparse_tensor_value.dense_shape)
    raise TypeError(f"Argument sparse_tensor_value={sparse_tensor_value} "
                    "is neither a SparseTensor nor SparseTensorValue.")

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.

    Raises:
      ValueError: When building an eager SparseTensor if `dense_shape` is
        unknown or contains unknown elements (None or -1).
    """
    raw_components = [indices, values, dense_shape]
    with ops.name_scope(None, "SparseTensor", raw_components):
      # The conversions stay in this order (indices, values, dense_shape) so
      # that ops are created inside the name scope in a stable sequence.
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.convert_to_tensor(values, name="values")
      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
      # Static (possibly partial) `TensorShape` derived from `dense_shape`;
      # this is what `shape` / `get_shape()` return.
      static_dense_shape = tensor_util.constant_value_as_shape(dense_shape)

    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape
    self._dense_shape_default = static_dense_shape

    # Static rank checks: indices is rank 2, values and dense_shape are rank 1.
    index_dims = indices.shape.with_rank(2).dims
    value_dims = values.shape.with_rank(1).dims
    shape_dims = dense_shape.shape.with_rank(1).dims

    # Rows of `indices` must be compatible with the number of `values`.
    index_dims[0].assert_is_compatible_with(value_dims[0])
    # Columns of `indices` must be compatible with the length of
    # `dense_shape`.
    index_dims[1].assert_is_compatible_with(shape_dims[0])

  def get_shape(self):
    """Returns the static shape of the represented dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  @property
  def indices(self):
    """The positions of the non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with dense_shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values of the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  def with_values(self, new_values):
    """Returns a copy of `self` with `values` replaced by `new_values`.

    The result is a new `SparseTensor` sharing this tensor's `indices` and
    `dense_shape`; only the values differ.

    Args:
      new_values: The values of the new `SparseTensor`. Needs to have the same
        shape as the current `.values` `Tensor`. May have a different type than
        the current `values`.

    Returns:
      A `SparseTensor` with identical indices and shape but updated values.

    Example usage:

    >>> st = tf.sparse.from_dense([[1, 0, 2, 0], [3, 0, 0, 4]])
    >>> tf.sparse.to_dense(st.with_values([10, 20, 30, 40]))  # 4 nonzero values
    <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
    array([[10,  0, 20,  0],
           [30,  0,  0, 40]], dtype=int32)>

    """
    return SparseTensor(
        indices=self._indices,
        values=new_values,
        dense_shape=self._dense_shape)

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self._values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """The static shape of the represented dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    # Read from `indices`; all three components are documented to share it.
    return self._indices.graph

  def __str__(self):
    components = (self._indices, self._values, self._dense_shape)
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % components

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Running this executes every preceding operation needed to produce the
    inputs of the operation that produces this tensor.

    *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    fetches = [self.indices, self.values, self.dense_shape]
    indices_value, values_value, dense_shape_value = (
        _eval_using_default_session(fetches, feed_dict, self.graph, session))
    return SparseTensorValue(indices_value, values_value, dense_shape_value)

  @staticmethod
  def _override_operator(operator, func):
    # Delegates to the `ops` helper, targeting the `SparseTensor` class.
    _override_helper(SparseTensor, operator, func)

  @property
  def _type_spec(self):
    # The spec is fully described by the static dense shape and value dtype.
    return SparseTensorSpec(shape=self.shape, dtype=self.dtype)

  def _shape_invariant_to_type_spec(self, shape):
    # From the tf.while_loop docs: "If a loop variable is a SparseTensor, the
    # shape invariant must be TensorShape([r]) where r is the rank of the dense
    # tensor represented by the sparse tensor. It means the shapes of the three
    # tensors of the SparseTensor are ([None], [None, r], [r])." NOTE: The
    # shape invariant here is the shape of the SparseTensor.dense_shape
    # property. It must be the shape of a vector.
    invariant_ndims = shape.ndims
    if invariant_ndims is not None and invariant_ndims != 1:
      raise ValueError(f"Expected a shape with 1 dimension. Obtained: {shape} "
                       f"which has {shape.ndims} dimensions.")
    # The single entry of the invariant is the rank of the dense tensor.
    dense_rank = tensor_shape.dimension_value(shape[0])
    return SparseTensorSpec(tensor_shape.unknown_shape(dense_rank), self.dtype)

  def consumers(self):
    """Returns the result of the inherited `_consumers()` helper."""
    return self._consumers()
283
284
# Plain namedtuple carrying the three components of a sparse tensor. It is the
# type returned by `SparseTensor.eval()` and one of the two types accepted by
# `is_sparse()` / `SparseTensor.from_value()`.
SparseTensorValue = collections.namedtuple("SparseTensorValue",
                                           ["indices", "values", "dense_shape"])
# Export the namedtuple through `tf_export` under the v1 name list only.
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
# Register the namedtuple with `_pywrap_utils` under the name
# "SparseTensorValue".
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)
289
290
@tf_export("SparseTensorSpec")
@type_spec.register("tf.SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.sparse.SparseTensor`."""

  __slots__ = ["_shape", "_dtype"]

  @property
  def value_type(self):
    """The Python type described by this spec (`SparseTensor`)."""
    return SparseTensor

  def __init__(self, shape=None, dtype=dtypes.float32):
    """Constructs a type specification for a `tf.sparse.SparseTensor`.

    Args:
      shape: The dense shape of the `SparseTensor`, or `None` to allow any dense
        shape.
      dtype: `tf.DType` of values in the `SparseTensor`.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)

  def _serialize(self):
    # The spec is fully described by its (shape, dtype) pair.
    return (self._shape, self._dtype)

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the SparseTensor."""
    return self._dtype

  @property
  def shape(self):
    """The `tf.TensorShape` specified by this type for the SparseTensor."""
    return self._shape

  @property
  def _component_specs(self):
    # Components are listed in (indices, values, dense_shape) order. The
    # number of stored values is never part of the spec, so it stays unknown.
    rank = self._shape.ndims
    num_values = None
    indices_spec = tensor_spec.TensorSpec([num_values, rank], dtypes.int64)
    values_spec = tensor_spec.TensorSpec([num_values], self._dtype)
    dense_shape_spec = tensor_spec.TensorSpec([rank], dtypes.int64)
    return [indices_spec, values_spec, dense_shape_spec]

  def _to_components(self, value):
    # A `SparseTensorValue` is first turned into a `SparseTensor`.
    sparse = value
    if isinstance(sparse, SparseTensorValue):
      sparse = SparseTensor.from_value(sparse)
    return [sparse.indices, sparse.values, sparse.dense_shape]

  def _from_components(self, tensor_list):
    # All-numpy components are returned as a `SparseTensorValue`, but only
    # when TF2 behavior is disabled; everything else is a `SparseTensor`.
    all_numpy = all(isinstance(t, np.ndarray) for t in tensor_list)
    if all_numpy and not tf2.enabled():
      return SparseTensorValue(*tensor_list)
    return SparseTensor(*tensor_list)

  # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorSpec` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    # Box the three components into one variant tensor.
    sparse = SparseTensor.from_value(value)
    boxed = gen_sparse_ops.serialize_sparse(
        sparse.indices, sparse.values, sparse.dense_shape,
        out_type=dtypes.variant)
    return [boxed]

  def _to_batched_tensor_list(self, value):
    # Combine the spec's shape with whatever is statically known about the
    # value's dense_shape; a rank-0 result is rejected below.
    static_dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    merged_shape = self._shape.merge_with(static_dense_shape)
    if merged_shape.ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1. "
          f"Obtained input: {value}.")
    boxed = gen_sparse_ops.serialize_many_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)
    return [boxed]

  def _from_compatible_tensor_list(self, tensor_list):
    indices, values, dense_shape = gen_sparse_ops.deserialize_sparse(
        tensor_list[0], self._dtype)
    spec_shape = self._shape
    rank = spec_shape.ndims
    indices.set_shape([None, rank])
    # We restore the dense_shape from the SparseTypeSpec. This is necessary
    # for shape inference when using placeholder SparseTensors in function
    # tracing.
    if spec_shape.is_fully_defined():
      # Every dimension is known: use the spec's shape as a constant.
      dense_shape = ops.convert_to_tensor(
          spec_shape, dtype=dtypes.int64, name="shape")
    elif (spec_shape.rank is not None and
          any(dim.value is not None for dim in spec_shape.dims)):
      # Some dimensions are known: overwrite just those entries of the
      # deserialized dense_shape with constants and keep the rest.
      # array_ops imports sparse_tensor.py. Local import to avoid import cycle.
      from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
      shape_dtype = dense_shape.dtype
      pieces = array_ops.unstack(dense_shape, num=spec_shape.rank)
      patched = [
          piece if dim.value is None
          else constant_op.constant(dim.value, shape_dtype)
          for piece, dim in zip(pieces, spec_shape.dims)
      ]
      dense_shape = array_ops.stack(patched)
    else:
      # No dimension values are known: only record the static length.
      dense_shape.set_shape([rank])

    return SparseTensor(indices, values, dense_shape)

  def _batch(self, batch_size):
    # Prepend the batch dimension to the dense shape.
    batched_shape = tensor_shape.TensorShape([batch_size]).concatenate(
        self._shape)
    return SparseTensorSpec(batched_shape, self._dtype)

  def _unbatch(self):
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    # Drop the leading (batch) dimension.
    element_shape = self._shape[1:]
    return SparseTensorSpec(element_shape, self._dtype)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return SparseTensor

  @classmethod
  def from_value(cls, value):
    """Builds a spec describing `value`.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      A spec of type `cls` built from the shape and dtype of `value`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if isinstance(value, SparseTensor):
      return cls(value.shape, value.dtype)
    if not isinstance(value, SparseTensorValue):
      raise TypeError("Expected SparseTensor or SparseTensorValue. Received: "
                      f"{value} of type {type(value).__name__}.")
    if isinstance(value.values, np.ndarray):
      # numpy-backed value: read shape and dtype directly off the arrays.
      return cls(value.dense_shape, value.values.dtype)
    # Otherwise convert to a `SparseTensor` first and recurse.
    return cls.from_value(SparseTensor.from_value(value))
428
429
# Register `SparseTensorSpec.from_value` as the value-to-spec converter for
# both sparse value types, `SparseTensor` and `SparseTensorValue`.
# TODO(b/133606651) Delete the SparseTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic).  Do *not* delete the SparseTensorValue registration.
type_spec.register_type_spec_from_value_converter(
    SparseTensor, SparseTensorSpec.from_value)
type_spec.register_type_spec_from_value_converter(
    SparseTensorValue, SparseTensorSpec.from_value)
437
438
@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.

  Sparse inputs (`SparseTensor` / `SparseTensorValue`) come back as a
  `SparseTensor`; anything else goes through the regular `Tensor` conversion.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If result type is incompatible with `dtype`.
  """
  # Normalize the requested dtype up front so both paths see a `DType`.
  dtype = None if dtype is None else dtypes.as_dtype(dtype)

  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)

  if not isinstance(value, SparseTensor):
    # Dense path: defer to the ordinary tensor conversion.
    return ops.convert_to_tensor(value, dtype=dtype, name=name)

  # Sparse path: the value is returned as-is, so a requested dtype can only be
  # checked for compatibility, not applied.
  if dtype and not dtype.is_compatible_with(value.dtype):
    raise RuntimeError(f"Sparse dtype mismatch. Requested: {dtype.name}, "
                       f" Actual: {value.dtype.name}")
  return value
466
467
def is_sparse(x):
  """Check whether `x` is sparse.

  An object counts as sparse when it is an instance of either
  `tf.sparse.SparseTensor` or `tf.compat.v1.SparseTensorValue`.

  Args:
    x: A python object to check.

  Returns:
    `True` iff `x` is a `tf.sparse.SparseTensor` or
    `tf.compat.v1.SparseTensorValue`.
  """
  sparse_types = (SparseTensor, SparseTensorValue)
  return isinstance(x, sparse_types)
482