xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/indexed_slices.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Indexed slices."""
16
17# pylint: disable=g-bad-name
18import collections
19import warnings
20
21import numpy as np
22
23from tensorflow.python import tf2
24from tensorflow.python.eager import context
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import composite_tensor_gradient
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import tensor_conversion_registry
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import type_spec
31from tensorflow.python.types import internal
32from tensorflow.python.util.compat import collections_abc
33from tensorflow.python.util.lazy_loader import LazyLoader
34from tensorflow.python.util.tf_export import tf_export
35
36
37# Use LazyLoader to avoid circular dependencies.
38#
# Note: these can all be changed to regular imports once all code has been
# updated to refer to the symbols defined in this module directly, rather than
# using the backwards-compatible aliases in ops.py.  (E.g.,
# "indexed_slices.IndexedSlices" rather than "ops.IndexedSlices".)
43math_ops = LazyLoader(
44    "math_ops", globals(),
45    "tensorflow.python.ops.math_ops")
46ops = LazyLoader(
47    "ops", globals(), "tensorflow.python.framework.ops")
48tensor_spec = LazyLoader(
49    "tensor_spec", globals(),
50    "tensorflow.python.framework.tensor_spec")
51tensor_util = LazyLoader(
52    "tensor_util", globals(),
53    "tensorflow.python.framework.tensor_util")
54
55
class IndexedSlicesCompositeTensorGradient(
    composite_tensor_gradient.CompositeTensorGradient):
  """Gradient protocol implementation for `IndexedSlices`.

  Only the `values` component of an `IndexedSlices` is exposed as a gradient
  component; `indices` and `dense_shape` are carried over unchanged.
  """

  def get_gradient_components(self, value):
    """Returns the `values` component of the `IndexedSlices` `value`."""
    gradient_source = value.values
    return gradient_source

  def replace_gradient_components(self, value, component_grads):
    """Returns a new `IndexedSlices` whose values are `component_grads`.

    The `indices` and `dense_shape` of the result are taken from `value`.
    """
    return IndexedSlices(
        values=component_grads,
        indices=value.indices,
        dense_shape=value.dense_shape)
65
66
67# TODO(mdan): Should IndexedSlices be a "tensor"?
@tf_export("IndexedSlices")
class IndexedSlices(internal.NativeObject, composite_tensor.CompositeTensor):
  """A sparse representation of a set of tensor slices at given indices.

  An `IndexedSlices` bundles two `Tensor` objects (plus an optional third):

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.
  * `dense_shape` (optional): A 1-D `Tensor` holding the shape of the
    corresponding dense tensor.

  It is typically used to represent a subset of a larger tensor `dense` of
  shape `[LARGE0, D1, .. , DN]` where `LARGE0 >> D0`. Each entry of `indices`
  is the position, along the first dimension of `dense`, of the matching
  slice in `values`.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. `tf.gather`).

  >>> v = tf.Variable([[0.,1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8]])
  >>> with tf.GradientTape() as tape:
  ...   r = tf.gather(v, [1,3])
  >>> index_slices = tape.gradient(r,v)
  >>> index_slices
  <...IndexedSlices object ...>
  >>> index_slices.indices.numpy()
  array([1, 3], dtype=int32)
  >>> index_slices.values.numpy()
  array([[1., 1., 1.],
         [1., 1., 1.]], dtype=float32)

  Contrast this representation with
  `tf.sparse.SparseTensor`,
  which uses multi-dimensional indices and scalar values.
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`.

    The arguments are stored as given; no validation or conversion is done.

    Args:
      values: The `Tensor` holding the slice values.
      indices: The 1-D `Tensor` holding the index of each slice.
      dense_shape: (Optional.) A 1-D `Tensor` holding the shape of the
        corresponding dense tensor, or `None` if it is not known.
    """
    self._values, self._indices, self._dense_shape = (
        values, indices, dense_shape)

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor.

    This is `None` when no dense shape was supplied at construction.
    """
    return self._dense_shape

  @property
  def shape(self):
    """Gets the `tf.TensorShape` representing the shape of the dense tensor.

    Returns:
      A `tf.TensorShape` object. It is a fully unknown shape when this
      `IndexedSlices` has no `dense_shape`.
    """
    dense_shape = self._dense_shape
    if dense_shape is not None:
      return tensor_util.constant_value_as_shape(dense_shape)
    return tensor_shape.TensorShape(None)

  @property
  def name(self):
    """The name of this `IndexedSlices` (the name of its `values`)."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor (the dtype of `values`)."""
    return self.values.dtype

  @property
  def graph(self):
    """The `Graph` that contains the values, indices, and shape tensors."""
    return self._values.graph

  def __str__(self):
    # The dense shape is only mentioned when one is present.
    if self._dense_shape is None:
      shape_suffix = ""
    else:
      shape_suffix = ", dense_shape=%s" % (self._dense_shape,)
    return "IndexedSlices(indices=%s, values=%s%s)" % (
        self._indices, self._values, shape_suffix)

  def __neg__(self):
    # Negation only touches the values; indices and dense shape are reused.
    negated_values = -self.values
    return IndexedSlices(negated_values, self.indices, self.dense_shape)

  __composite_gradient__ = IndexedSlicesCompositeTensorGradient()

  @property
  def _type_spec(self):
    """The `IndexedSlicesSpec` describing this value."""
    # The number of indices must agree with the leading dimension of `values`.
    indices_shape = self._indices.shape.merge_with(self._values.shape[:1])
    # Static dense shape as implied by `values` alone: the leading dimension
    # is unknown, the remaining ones match the slice dimensions.
    dense_shape = tensor_shape.TensorShape([None]).concatenate(
        self._values.shape[1:])
    dense_shape_dtype = None
    if self._dense_shape is not None:
      dense_shape_dtype = self._dense_shape.dtype
      # Refine with whatever is statically known about the `dense_shape`
      # tensor.
      dense_shape = dense_shape.merge_with(
          tensor_util.constant_value_as_shape(self._dense_shape))
    return IndexedSlicesSpec(
        shape=dense_shape,
        dtype=self.dtype,
        indices_dtype=self._indices.dtype,
        dense_shape_dtype=dense_shape_dtype,
        indices_shape=indices_shape)

  def _shape_invariant_to_type_spec(self, shape):
    """Builds an `IndexedSlicesSpec` from a shape invariant of `values`."""
    # From tf.while_loop docs: "If a loop variable is an IndexedSlices, the
    # shape invariant must be a shape invariant of the values tensor of the
    # IndexedSlices. It means the shapes of the three tensors of the
    # IndexedSlices are (shape, [shape[0]], [shape.ndims])."
    indices_shape = shape[:1]
    dense_shape = tensor_shape.TensorShape([None]).concatenate(shape[1:])
    dense_shape_dtype = (
        None if self._dense_shape is None else self._dense_shape.dtype)
    return IndexedSlicesSpec(
        shape=dense_shape,
        dtype=self.dtype,
        indices_dtype=self._indices.dtype,
        dense_shape_dtype=dense_shape_dtype,
        indices_shape=indices_shape)

  def consumers(self):
    """Returns the result of `self._consumers()`.

    `_consumers` is not defined in this class; presumably it is inherited
    from `CompositeTensor`.
    """
    return self._consumers()
208
209
# Plain-Python container with the same three fields as an `IndexedSlices`.
# `IndexedSlicesSpec._from_components` returns it when every component is a
# NumPy array and TF2 behavior is disabled.
IndexedSlicesValue = collections.namedtuple(
    typename="IndexedSlicesValue",
    field_names=("values", "indices", "dense_shape"))
212
213
@tf_export("IndexedSlicesSpec")
class IndexedSlicesSpec(type_spec.TypeSpec):
  """Type specification for a `tf.IndexedSlices`."""

  __slots__ = ["_shape", "_values_dtype", "_indices_dtype",
               "_dense_shape_dtype", "_indices_shape"]

  @property
  def value_type(self):
    """The Python type described by this spec (`IndexedSlices`)."""
    return IndexedSlices

  def __init__(self, shape=None, dtype=dtypes.float32,
               indices_dtype=dtypes.int64, dense_shape_dtype=None,
               indices_shape=None):
    """Constructs a type specification for a `tf.IndexedSlices`.

    Args:
      shape: The dense shape of the `IndexedSlices`, or `None` to allow any
        dense shape.
      dtype: `tf.DType` of values in the `IndexedSlices`.
      indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`.  One
        of `tf.int32` or `tf.int64`.
      dense_shape_dtype: `tf.DType` of the `dense_shape` in the `IndexedSlices`.
        One of `tf.int32`, `tf.int64`, or `None` (if the `IndexedSlices` has
        no `dense_shape` tensor).
      indices_shape: The shape of the `indices` component, which indicates
        how many slices are in the `IndexedSlices`. It is constrained to
        rank 1.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._values_dtype = dtypes.as_dtype(dtype)
    self._indices_dtype = dtypes.as_dtype(indices_dtype)
    # `None` is kept as-is: it records that there is no `dense_shape` tensor.
    self._dense_shape_dtype = (
        None if dense_shape_dtype is None
        else dtypes.as_dtype(dense_shape_dtype))
    self._indices_shape = tensor_shape.as_shape(indices_shape).with_rank(1)

  def _serialize(self):
    """Returns the constructor arguments of this spec, in order."""
    return (
        self._shape,
        self._values_dtype,
        self._indices_dtype,
        self._dense_shape_dtype,
        self._indices_shape,
    )

  @property
  def _component_specs(self):
    """`TensorSpec`s for values, indices and (if present) dense_shape."""
    # Values have one row per index, followed by the slice dimensions.
    value_shape = self._indices_shape.concatenate(self._shape[1:])
    values_spec = tensor_spec.TensorSpec(value_shape, self._values_dtype)
    indices_spec = tensor_spec.TensorSpec(
        self._indices_shape, self._indices_dtype)
    if self._dense_shape_dtype is None:
      return (values_spec, indices_spec)
    dense_shape_spec = tensor_spec.TensorSpec(
        [self._shape.ndims], self._dense_shape_dtype)
    return (values_spec, indices_spec, dense_shape_spec)

  def _to_components(self, value):
    """Splits `value` into a 2-tuple, or a 3-tuple if it has a dense shape."""
    if value.dense_shape is None:
      return (value.values, value.indices)
    return (value.values, value.indices, value.dense_shape)

  def _from_components(self, tensor_list):
    """Rebuilds a value from the components in `tensor_list`.

    Returns an `IndexedSlicesValue` when every component is a NumPy array and
    TF2 behavior is disabled; otherwise returns an `IndexedSlices`.
    """
    all_numpy = all(isinstance(t, np.ndarray) for t in tensor_list)
    if not all_numpy or tf2.enabled():
      return IndexedSlices(*tensor_list)
    if len(tensor_list) == 2:
      # No dense_shape component was provided.
      return IndexedSlicesValue(tensor_list[0], tensor_list[1], None)
    return IndexedSlicesValue(*tensor_list)
279
280
@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  An `IndexedSlices` or `SparseTensor` `value` is returned unmodified.
  Anything else is converted to a `Tensor` using `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  # Same as the internal variant, with ref tensors never requested.
  return internal_convert_to_tensor_or_indexed_slices(
      value, dtype, name, as_ref=False)
304
305
def internal_convert_to_tensor_or_indexed_slices(value,
                                                 dtype=None,
                                                 name=None,
                                                 as_ref=False):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  # An eager tensor encountered while not executing eagerly always goes
  # through `convert_to_tensor`, so this is decided before the pass-through
  # check below.
  eager_value_in_graph = (
      isinstance(value, ops.EagerTensor) and not context.executing_eagerly())
  # TODO(mdan): Name says tensor_or_indexed_slices. So do explicitly just that?
  if isinstance(value, internal.NativeObject) and not eager_value_in_graph:
    # Compare against None rather than relying on truthiness: a valid dtype
    # specification may evaluate as false (e.g. a NumPy dtype instance, whose
    # `len()` is 0), which would silently skip the documented check.
    if dtype is not None:
      requested_dtype = dtypes.as_dtype(dtype)
      if not requested_dtype.is_compatible_with(value.dtype):
        raise ValueError(
            "Incompatible tensor conversion requested to `dtype` "
            f"{requested_dtype.name} for `value` ({value}) with dtype"
            f" {value.dtype.name}.")
    return value
  return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
342
343
def internal_convert_n_to_tensor_or_indexed_slices(values,
                                                   dtype=None,
                                                   name=None,
                                                   as_ref=False):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified, and `None` entries are passed through as `None`.

  Args:
    values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects
      that can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If `values` is not iterable, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Iterable):
    raise TypeError("Argument `values` must be iterable.")

  def _convert_element(index, element):
    # `None` entries are kept in place and never converted.
    if element is None:
      return None
    element_name = None if name is None else "%s_%d" % (name, index)
    return internal_convert_to_tensor_or_indexed_slices(
        element, dtype=dtype, name=element_name, as_ref=as_ref)

  return [
      _convert_element(index, element)
      for index, element in enumerate(values)
  ]
382  return ret
383
384
def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified.

  Args:
    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  # Same as the internal variant, with ref tensors never requested.
  return internal_convert_n_to_tensor_or_indexed_slices(
      values, dtype, name, as_ref=False)
410
411
# Element-count threshold: converting a sparse representation to a dense one
# with at least this many elements triggers a warning to the user.
_LARGE_SPARSE_NUM_ELEMENTS = 100_000_000
415
416
def _indexed_slices_to_tensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a Tensor.

  NOTE(mrry): This function is potentially expensive.

  When not executing eagerly, a warning is issued if the dense result is
  statically known to be large, or if its size cannot be determined
  statically.

  Args:
    value: An ops.IndexedSlices object.
    dtype: The dtype of the Tensor to be returned.
    name: Optional name to use for the returned Tensor.
    as_ref: True if a ref is requested. Ignored by this function.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If `dtype` is not compatible with the dtype of `value`, or if
      `value` has no `dense_shape`.
  """
  del as_ref  # Unused.
  # Check for presence explicitly instead of relying on the truthiness of the
  # dtype object.
  if dtype is not None and not dtype.is_compatible_with(value.dtype):
    raise ValueError(
        f"Incompatible tensor conversion requested to `dtype` {dtype.name} for "
        f"IndexedSlices ({value}) with dtype {value.dtype.name}")
  if value.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices for argument `value` "
        f"without dense_shape: {value!s}")
  # TODO(mrry): Consider adding static shape information to
  # IndexedSlices, to avoid using numpy here.
  if not context.executing_eagerly():
    dense_shape_value = tensor_util.constant_value(value.dense_shape)
    if dense_shape_value is not None:
      # Accumulate in int64 explicitly. For a narrower integer input NumPy
      # multiplies in the platform default integer, which can be 32-bit; the
      # product could then overflow and the warning be silently skipped.
      num_elements = np.prod(dense_shape_value, dtype=np.int64)
      if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
        warnings.warn(
            "Converting sparse IndexedSlices to a dense Tensor with %d "
            "elements. This may consume a large amount of memory." %
            num_elements)
    elif value.dense_shape.op.type != "VariableShape":
      # VariableShape may hide static shapes behind a resource handle
      # producing a warning that isn't that useful to users.
      warnings.warn(
          "Converting sparse IndexedSlices(%s) to a dense Tensor of unknown "
          "shape. This may consume a large amount of memory." % value)
  # The first entry of the dense shape is passed as the third positional
  # argument of `unsorted_segment_sum`.
  return math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)
463
464
# Register the densifying conversion so that the tensor conversion registry
# uses `_indexed_slices_to_tensor` for `IndexedSlices` values.
tensor_conversion_registry.register_tensor_conversion_function(
    IndexedSlices,
    _indexed_slices_to_tensor,
)
467