# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparse tensors."""
# pylint: disable=g-bad-name
import collections

import numpy as np

from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
from tensorflow.python import tf2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.types import internal
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export

# Module-local aliases for two private helpers of `ops`, used by
# `SparseTensor.eval` and `SparseTensor._override_operator` respectively.
# pylint: disable=protected-access
_eval_using_default_session = ops._eval_using_default_session
_override_helper = ops._override_helper
# pylint: enable=protected-access


@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(internal.NativeObject, composite_tensor.CompositeTensor):
  """Represents a sparse tensor.

  TensorFlow represents a sparse tensor as three separate dense tensors:
  `indices`, `values`, and `dense_shape`.  In Python, the three tensors are
  collected into a `SparseTensor` class for ease of use.  If you have separate
  `indices`, `values`, and `dense_shape` tensors, wrap them in a `SparseTensor`
  object before passing to the ops below.

  Concretely, the sparse tensor `SparseTensor(indices, values, dense_shape)`
  comprises the following components, where `N` and `ndims` are the number
  of values and number of dimensions in the `SparseTensor`, respectively:

  * `indices`: A 2-D int64 tensor of shape `[N, ndims]`, which specifies the
    indices of the elements in the sparse tensor that contain nonzero values
    (elements are zero-indexed). For example, `indices=[[1,3], [2,4]]` specifies
    that the elements with indexes of [1,3] and [2,4] have nonzero values.

  * `values`: A 1-D tensor of any type and shape `[N]`, which supplies the
    values for each element in `indices`. For example, given `indices=[[1,3],
    [2,4]]`, the parameter `values=[18, 3.6]` specifies that element [1,3] of
    the sparse tensor has a value of 18, and element [2,4] of the tensor has a
    value of 3.6.

  * `dense_shape`: A 1-D int64 tensor of shape `[ndims]`, which specifies the
    dense_shape of the sparse tensor. Takes a list indicating the number of
    elements in each dimension. For example, `dense_shape=[3,6]` specifies a
    two-dimensional 3x6 tensor, `dense_shape=[2,3,4]` specifies a
    three-dimensional 2x3x4 tensor, and `dense_shape=[9]` specifies a
    one-dimensional tensor with 9 elements.

  The corresponding dense tensor satisfies:

  ```python
  dense.shape = dense_shape
  dense[tuple(indices[i])] = values[i]
  ```

  By convention, `indices` should be sorted in row-major order (or equivalently
  lexicographic order on the tuples `indices[i]`). This is not enforced when
  `SparseTensor` objects are constructed, but most ops assume correct ordering.
  If the ordering of sparse tensor `st` is wrong, a fixed version can be
  obtained by calling `tf.sparse.reorder(st)`.

  Example: The sparse tensor

  ```python
  SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
  ```

  represents the dense tensor

  ```python
  [[1, 0, 0, 0]
   [0, 0, 2, 0]
   [0, 0, 0, 0]]
  ```
  """

  @classmethod
  def from_value(cls, sparse_tensor_value):
    """Builds a `SparseTensor` from a sparse value object.

    Args:
      sparse_tensor_value: A `SparseTensor` or `SparseTensorValue`; its
        `indices`, `values` and `dense_shape` attributes are passed to the
        `SparseTensor` constructor.

    Returns:
      A new `SparseTensor`. Note that the result is always constructed as
      `SparseTensor`, not as `cls`.

    Raises:
      TypeError: If `is_sparse(sparse_tensor_value)` is false.
    """
    if not is_sparse(sparse_tensor_value):
      raise TypeError(f"Argument sparse_tensor_value={sparse_tensor_value} "
                      "is neither a SparseTensor nor SparseTensorValue.")
    return SparseTensor(
        indices=sparse_tensor_value.indices,
        values=sparse_tensor_value.values,
        dense_shape=sparse_tensor_value.dense_shape)

  def __init__(self, indices, values, dense_shape):
    """Creates a `SparseTensor`.

    Args:
      indices: A 2-D int64 tensor of shape `[N, ndims]`.
      values: A 1-D tensor of any type and shape `[N]`.
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.

    Raises:
      ValueError: When building an eager SparseTensor if `dense_shape` is
        unknown or contains unknown elements (None or -1).
    """
    # All three components are converted inside a single "SparseTensor" name
    # scope; `indices` and `dense_shape` are forced to int64, while `values`
    # keeps whatever dtype conversion infers.
    with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]):
      indices = ops.convert_to_tensor(
          indices, name="indices", dtype=dtypes.int64)
      # TODO(touts): Consider adding mutable_values() when 'values'
      # is a VariableOp and updating users of SparseTensor.
      values = ops.convert_to_tensor(values, name="values")

      dense_shape = ops.convert_to_tensor(
          dense_shape, name="dense_shape", dtype=dtypes.int64)
      # Shape object derived from `dense_shape`; this is what `shape` and
      # `get_shape()` return.
      dense_shape_default = tensor_util.constant_value_as_shape(dense_shape)

    self._indices = indices
    self._values = values
    self._dense_shape = dense_shape
    self._dense_shape_default = dense_shape_default

    # Constrain the static ranks of the components: `indices` must be rank 2,
    # `values` and `dense_shape` must be rank 1.
    indices_shape = indices.shape.with_rank(2)
    values_shape = values.shape.with_rank(1)
    dense_shape_shape = dense_shape.shape.with_rank(1)

    # Assert number of rows in indices match the number of elements in values.
    indices_shape.dims[0].assert_is_compatible_with(values_shape.dims[0])
    # Assert number of columns in indices matches the number of elements in
    # dense_shape.
    indices_shape.dims[1].assert_is_compatible_with(dense_shape_shape.dims[0])

  def get_shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Returns the same object as the `shape` property.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  @property
  def indices(self):
    """The indices of non-zero values in the represented dense tensor.

    Returns:
      A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
        number of non-zero values in the tensor, and `ndims` is the rank.
    """
    return self._indices

  @property
  def values(self):
    """The non-zero values in the represented dense tensor.

    Returns:
      A 1-D Tensor of any data type.
    """
    return self._values

  def with_values(self, new_values):
    """Returns a copy of `self` with `values` replaced by `new_values`.

    This method produces a new `SparseTensor` that has the same nonzero
    `indices` and same `dense_shape`, but updated values.

    Args:
      new_values: The values of the new `SparseTensor`. Needs to have the same
        shape as the current `.values` `Tensor`. May have a different type than
        the current `values`.

    Returns:
      A `SparseTensor` with identical indices and shape but updated values.

    Example usage:

    >>> st = tf.sparse.from_dense([[1, 0, 2, 0], [3, 0, 0, 4]])
    >>> tf.sparse.to_dense(st.with_values([10, 20, 30, 40]))  # 4 nonzero values
    <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
    array([[10,  0, 20,  0],
           [30,  0,  0, 40]], dtype=int32)>

    """
    return SparseTensor(self._indices, new_values, self._dense_shape)

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self._values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self._values.dtype

  @property
  def dense_shape(self):
    """A 1-D Tensor of int64 representing the shape of the dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Get the `TensorShape` representing the shape of the dense tensor.

    Returns:
      A `TensorShape` object.
    """
    return self._dense_shape_default

  @property
  def graph(self):
    """The `Graph` that contains the index, value, and dense_shape tensors."""
    # Only the graph of `indices` is consulted here.
    return self._indices.graph

  def __str__(self):
    """Returns a string showing the three component tensors."""
    return "SparseTensor(indices=%s, values=%s, dense_shape=%s)" % (
        self._indices, self._values, self._dense_shape)

  def eval(self, feed_dict=None, session=None):
    """Evaluates this sparse tensor in a `Session`.

    Calling this method will execute all preceding operations that
    produce the inputs needed for the operation that produces this
    tensor.

    *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
    launched in a session, and either a default session must be
    available, or `session` must be specified explicitly.

    Args:
      feed_dict: A dictionary that maps `Tensor` objects to feed values. See
        `tf.Session.run` for a description of the valid feed values.
      session: (Optional.) The `Session` to be used to evaluate this sparse
        tensor. If none, the default session will be used.

    Returns:
      A `SparseTensorValue` object.
    """
    # The three components are evaluated together in a single call.
    indices, values, dense_shape = _eval_using_default_session(
        [self.indices, self.values, self.dense_shape], feed_dict, self.graph,
        session)
    return SparseTensorValue(indices, values, dense_shape)

  @staticmethod
  def _override_operator(operator, func):
    """Forwards to `ops._override_helper` with `SparseTensor` as the class.

    Args:
      operator: Passed through unchanged as the helper's `operator` argument.
      func: Passed through unchanged as the helper's `func` argument.
    """
    _override_helper(SparseTensor, operator, func)

  @property
  def _type_spec(self):
    """A `SparseTensorSpec` built from this tensor's `shape` and `dtype`."""
    return SparseTensorSpec(self.shape, self.dtype)

  def _shape_invariant_to_type_spec(self, shape):
    """Converts a `tf.while_loop` shape invariant into a `SparseTensorSpec`.

    Args:
      shape: A shape describing the `dense_shape` vector (not the dense
        tensor itself); it must have unknown rank or exactly one dimension.

    Returns:
      A `SparseTensorSpec` with this tensor's dtype and an unknown shape whose
      rank is taken from `shape[0]`.

    Raises:
      ValueError: If `shape` has a known rank other than 1.
    """
    # From the tf.while_loop docs: "If a loop variable is a SparseTensor, the
    # shape invariant must be TensorShape([r]) where r is the rank of the dense
    # tensor represented by the sparse tensor. It means the shapes of the three
    # tensors of the SparseTensor are ([None], [None, r], [r]). NOTE: The shape
    # invariant here is the shape of the SparseTensor.dense_shape property. It
    # must be the shape of a vector."
    if shape.ndims is not None and shape.ndims != 1:
      raise ValueError(f"Expected a shape with 1 dimension. Obtained: {shape} "
                       f"which has {shape.ndims} dimensions.")
    rank = tensor_shape.dimension_value(shape[0])
    return SparseTensorSpec(tensor_shape.unknown_shape(rank), self.dtype)

  def consumers(self):
    """Returns the result of `self._consumers()`.

    NOTE(review): `_consumers` is not defined in this class; presumably it is
    inherited from `CompositeTensor` -- confirm.
    """
    return self._consumers()


# Plain-value counterpart of `SparseTensor`: a namedtuple holding the same
# three components. `SparseTensor.eval` returns one of these.
SparseTensorValue = collections.namedtuple("SparseTensorValue",
                                           ["indices", "values", "dense_shape"])
tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
_pywrap_utils.RegisterType("SparseTensorValue", SparseTensorValue)


@tf_export("SparseTensorSpec")
@type_spec.register("tf.SparseTensorSpec")
class SparseTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.sparse.SparseTensor`."""

  __slots__ = ["_shape", "_dtype"]

  # Read-only property that always yields the `SparseTensor` class; written as
  # a lambda so that `SparseTensor` is looked up at access time.
  value_type = property(lambda self: SparseTensor)

  def __init__(self, shape=None, dtype=dtypes.float32):
    """Constructs a type specification for a `tf.sparse.SparseTensor`.

    Args:
      shape: The dense shape of the `SparseTensor`, or `None` to allow any dense
        shape.
      dtype: `tf.DType` of values in the `SparseTensor`.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._dtype = dtypes.as_dtype(dtype)

  def _serialize(self):
    """Returns the `(shape, dtype)` tuple that defines this spec."""
    return (self._shape, self._dtype)

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the SparseTensor."""
    return self._dtype

  @property
  def shape(self):
    """The `tf.TensorShape` specified by this type for the SparseTensor."""
    return self._shape

  @property
  def _component_specs(self):
    """`TensorSpec`s for the indices, values and dense_shape components.

    The number of values is always left unknown (`None`); the rank comes from
    this spec's shape and may itself be `None`.
    """
    rank = self._shape.ndims
    num_values = None
    return [
        tensor_spec.TensorSpec([num_values, rank], dtypes.int64),
        tensor_spec.TensorSpec([num_values], self._dtype),
        tensor_spec.TensorSpec([rank], dtypes.int64)]

  def _to_components(self, value):
    """Returns `[indices, values, dense_shape]` for `value`.

    A `SparseTensorValue` is first converted with `SparseTensor.from_value`.
    """
    if isinstance(value, SparseTensorValue):
      value = SparseTensor.from_value(value)
    return [value.indices, value.values, value.dense_shape]

  def _from_components(self, tensor_list):
    """Rebuilds a sparse object from its three components.

    Returns a `SparseTensorValue` when every component is a numpy array and
    `tf2.enabled()` is false; otherwise returns a `SparseTensor`.
    """
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      return SparseTensorValue(*tensor_list)
    else:
      return SparseTensor(*tensor_list)

  # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    """A single variant `TensorSpec` of unknown shape."""
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorSpec` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    """Boxes `value` into a one-element list holding a variant tensor.

    `value` is passed through `SparseTensor.from_value` first, then encoded
    with the `serialize_sparse` op.
    """
    value = SparseTensor.from_value(value)
    return [gen_sparse_ops.serialize_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _to_batched_tensor_list(self, value):
    """Boxes `value` with the `serialize_many_sparse` op.

    Unlike `_to_tensor_list`, `value` is used as given, without going through
    `SparseTensor.from_value`.

    Raises:
      ValueError: If this spec's shape merged with the shape derived from
        `value.dense_shape` has rank 0.
    """
    dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    if self._shape.merge_with(dense_shape).ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1. "
          f"Obtained input: {value}.")
    return [gen_sparse_ops.serialize_many_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _from_compatible_tensor_list(self, tensor_list):
    """Unboxes `tensor_list[0]` into a `SparseTensor`.

    The deserialized `indices` and `dense_shape` are annotated (or, for
    `dense_shape`, partly or wholly replaced) with whatever static shape
    information this spec carries.
    """
    tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
    indices, values, dense_shape = tensor_list
    rank = self._shape.ndims
    indices.set_shape([None, rank])
    # We restore the dense_shape from the SparseTensorSpec. This is necessary
    # for shape inference when using placeholder SparseTensors in function
    # tracing.
    if self._shape.is_fully_defined():
      # Every dimension is known: use the spec's shape as the dense_shape.
      dense_shape = ops.convert_to_tensor(
          self._shape, dtype=dtypes.int64, name="shape")
    elif (self._shape.rank is not None and
          any(dim.value is not None for dim in self._shape.dims)):
      # Some dimensions are known: keep the deserialized entries for the
      # unknown ones and substitute constants for the known ones.
      # array_ops imports sparse_tensor.py. Local import to avoid import cycle.
      from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
      pieces = array_ops.unstack(dense_shape, num=self._shape.rank)
      for i, dim in enumerate(self._shape.dims):
        if dim.value is not None:
          pieces[i] = constant_op.constant(dim.value, dense_shape.dtype)
      dense_shape = array_ops.stack(pieces)
    else:
      # No dimension is known: only the length of dense_shape can be set
      # (and `rank` itself may be None).
      dense_shape.set_shape([rank])

    return SparseTensor(indices, values, dense_shape)

  def _batch(self, batch_size):
    """Returns a spec whose shape is `[batch_size]` prepended to this shape."""
    return SparseTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    """Returns a spec with the leading dimension of this shape removed.

    Raises:
      ValueError: If this spec's shape has rank 0.
    """
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return SparseTensorSpec(self._shape[1:], self._dtype)

  def _to_legacy_output_types(self):
    """Returns this spec's dtype."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Returns this spec's shape."""
    return self._shape

  def _to_legacy_output_classes(self):
    """Returns the `SparseTensor` class."""
    return SparseTensor

  @classmethod
  def from_value(cls, value):
    """Builds a spec describing `value`.

    Args:
      value: A `SparseTensor` or `SparseTensorValue`.

    Returns:
      An instance of `cls`. For a `SparseTensor` it is built from
      `value.shape` and `value.dtype`. For a `SparseTensorValue` whose
      `values` is a numpy array it is built from `value.dense_shape` and the
      array's dtype; any other `SparseTensorValue` is first converted with
      `SparseTensor.from_value`.

    Raises:
      TypeError: If `value` is neither a `SparseTensor` nor a
        `SparseTensorValue`.
    """
    if isinstance(value, SparseTensor):
      return cls(value.shape, value.dtype)
    if isinstance(value, SparseTensorValue):
      if isinstance(value.values, np.ndarray):
        return cls(value.dense_shape, value.values.dtype)
      else:
        return cls.from_value(SparseTensor.from_value(value))
    else:
      raise TypeError("Expected SparseTensor or SparseTensorValue. Received: "
                      f"{value} of type {type(value).__name__}.")


# TODO(b/133606651) Delete the SparseTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic).  Do *not* delete the SparseTensorValue registration.
# Register `SparseTensorSpec.from_value` as the type-spec-from-value converter
# for both sparse classes.
type_spec.register_type_spec_from_value_converter(
    SparseTensor, SparseTensorSpec.from_value)
type_spec.register_type_spec_from_value_converter(
    SparseTensorValue, SparseTensorSpec.from_value)


@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.

  Sparse inputs are returned as a `SparseTensor` (a `SparseTensorValue` is
  wrapped first); every other input is handed to `ops.convert_to_tensor`.

  Args:
    value: A `SparseTensor`, `SparseTensorValue`, or an object whose type has a
      registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `SparseTensor` or `Tensor` based on `value`.

  Raises:
    RuntimeError: If result type is incompatible with `dtype`.
  """
  dtype = None if dtype is None else dtypes.as_dtype(dtype)

  if isinstance(value, SparseTensorValue):
    value = SparseTensor.from_value(value)

  # Dense path: anything that is not sparse by now goes through the regular
  # tensor conversion.
  if not isinstance(value, SparseTensor):
    return ops.convert_to_tensor(value, dtype=dtype, name=name)

  # Sparse path: no conversion is performed, the requested dtype is only
  # checked for compatibility.
  dtype_is_acceptable = (not dtype) or dtype.is_compatible_with(value.dtype)
  if dtype_is_acceptable:
    return value
  raise RuntimeError(f"Sparse dtype mismatch. Requested: {dtype.name}, "
                     f" Actual: {value.dtype.name}")


def is_sparse(x):
  """Check whether `x` is sparse.

  Check whether an object is a `tf.sparse.SparseTensor` or
  `tf.compat.v1.SparseTensorValue`.

  Args:
    x: A python object to check.

  Returns:
    `True` iff `x` is a `tf.sparse.SparseTensor` or
    `tf.compat.v1.SparseTensorValue`.
  """
  sparse_classes = (SparseTensor, SparseTensorValue)
  return isinstance(x, sparse_classes)