1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Indexed slices.""" 16 17# pylint: disable=g-bad-name 18import collections 19import warnings 20 21import numpy as np 22 23from tensorflow.python import tf2 24from tensorflow.python.eager import context 25from tensorflow.python.framework import composite_tensor 26from tensorflow.python.framework import composite_tensor_gradient 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import tensor_conversion_registry 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.framework import type_spec 31from tensorflow.python.types import internal 32from tensorflow.python.util.compat import collections_abc 33from tensorflow.python.util.lazy_loader import LazyLoader 34from tensorflow.python.util.tf_export import tf_export 35 36 37# Use LazyLoader to avoid circular dependencies. 38# 39# Note: these can all be changed to regular imports once all code has been 40# updated to refer the symbols defined in this module directly, rather than 41# using the backwards-compatible aliases in ops.py. (E.g., 42# "indexed_slices.IndexedSlices" rather than "ops.IndexedSlices".) 
math_ops = LazyLoader(
    "math_ops", globals(),
    "tensorflow.python.ops.math_ops")
ops = LazyLoader(
    "ops", globals(), "tensorflow.python.framework.ops")
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
tensor_util = LazyLoader(
    "tensor_util", globals(),
    "tensorflow.python.framework.tensor_util")


class IndexedSlicesCompositeTensorGradient(
    composite_tensor_gradient.CompositeTensorGradient):
  """CompositeTensorGradient for IndexedSlices.

  Only the `values` component carries gradients; `indices` and `dense_shape`
  are reused unchanged when gradients are substituted back in.
  """

  def get_gradient_components(self, value):
    """Returns the component of `value` that gradients flow through."""
    return value.values

  def replace_gradient_components(self, value, component_grads):
    """Builds an `IndexedSlices` whose values are `component_grads`.

    Args:
      value: The original `IndexedSlices`; its `indices` and `dense_shape`
        are reused.
      component_grads: The gradient with respect to `value.values`.

    Returns:
      A new `IndexedSlices`.
    """
    return IndexedSlices(component_grads, value.indices, value.dense_shape)


# TODO(mdan): Should IndexedSlices be a "tensor"?
@tf_export("IndexedSlices")
class IndexedSlices(internal.NativeObject, composite_tensor.CompositeTensor):
  """A sparse representation of a set of tensor slices at given indices.

  This class is a simple wrapper for a pair of `Tensor` objects:

  * `values`: A `Tensor` of any dtype with shape `[D0, D1, ..., Dn]`.
  * `indices`: A 1-D integer `Tensor` with shape `[D0]`.

  An optional third `Tensor`, `dense_shape`, records the shape of the dense
  tensor being represented.

  An `IndexedSlices` is typically used to represent a subset of a larger
  tensor `dense` of shape `[LARGE0, D1, ..., Dn]` where `LARGE0 >> D0`.
  The values in `indices` are the indices in the first dimension of
  the slices that have been extracted from the larger tensor.

  The dense tensor `dense` represented by an `IndexedSlices` `slices` has

  ```python
  dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]
  ```

  The `IndexedSlices` class is used principally in the definition of
  gradients for operations that have sparse gradients
  (e.g. `tf.gather`).

  >>> v = tf.Variable([[0.,1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 8]])
  >>> with tf.GradientTape() as tape:
  ...   r = tf.gather(v, [1,3])
  >>> index_slices = tape.gradient(r,v)
  >>> index_slices
  <...IndexedSlices object ...>
  >>> index_slices.indices.numpy()
  array([1, 3], dtype=int32)
  >>> index_slices.values.numpy()
  array([[1., 1., 1.],
         [1., 1., 1.]], dtype=float32)

  Contrast this representation with
  `tf.sparse.SparseTensor`,
  which uses multi-dimensional indices and scalar values.
  """

  def __init__(self, values, indices, dense_shape=None):
    """Creates an `IndexedSlices`.

    No validation is performed here; the components are stored as given.
    """
    self._values = values
    self._indices = indices
    self._dense_shape = dense_shape

  @property
  def values(self):
    """A `Tensor` containing the values of the slices."""
    return self._values

  @property
  def indices(self):
    """A 1-D `Tensor` containing the indices of the slices."""
    return self._indices

  @property
  def dense_shape(self):
    """A 1-D `Tensor` containing the shape of the corresponding dense tensor."""
    return self._dense_shape

  @property
  def shape(self):
    """Gets the `tf.TensorShape` representing the shape of the dense tensor.

    Returns:
      A `tf.TensorShape` object. Fully unknown when there is no `dense_shape`.
    """
    if self._dense_shape is None:
      return tensor_shape.TensorShape(None)

    return tensor_util.constant_value_as_shape(self._dense_shape)

  @property
  def name(self):
    """The name of this `IndexedSlices`."""
    return self.values.name

  @property
  def device(self):
    """The name of the device on which `values` will be produced, or `None`."""
    return self.values.device

  @property
  def op(self):
    """The `Operation` that produces `values` as an output."""
    return self.values.op

  @property
  def dtype(self):
    """The `DType` of elements in this tensor."""
    return self.values.dtype

  @property
  def graph(self):
    """The `Graph` that contains the values, indices, and shape tensors."""
    return self._values.graph

  def __str__(self):
    return "IndexedSlices(indices=%s, values=%s%s)" % (
        self._indices, self._values,
        (", dense_shape=%s" %
         (self._dense_shape,)) if self._dense_shape is not None else "")

  def __neg__(self):
    # Negation only touches the values; indices and dense_shape are shared.
    return IndexedSlices(-self.values, self.indices, self.dense_shape)

  __composite_gradient__ = IndexedSlicesCompositeTensorGradient()

  @property
  def _type_spec(self):
    # One index per slice, so the indices shape must agree with the leading
    # dimension of `values`.
    indices_shape = self._indices.shape.merge_with(self._values.shape[:1])
    # The dense tensor's first dimension is unknown a priori; the remaining
    # dimensions are those of a single slice.
    dense_shape = tensor_shape.TensorShape([None]).concatenate(
        self._values.shape[1:])
    if self._dense_shape is not None:
      dense_shape_dtype = self._dense_shape.dtype
      dense_shape = dense_shape.merge_with(
          tensor_util.constant_value_as_shape(self._dense_shape))
    else:
      dense_shape_dtype = None
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def _shape_invariant_to_type_spec(self, shape):
    # From tf.while_loop docs: "If a loop variable is an IndexedSlices, the
    # shape invariant must be a shape invariant of the values tensor of the
    # IndexedSlices. It means the shapes of the three tensors of the
    # IndexedSlices are (shape, [shape[0]], [shape.ndims])."
    indices_shape = shape[:1]
    dense_shape = tensor_shape.TensorShape([None]).concatenate(shape[1:])
    if self._dense_shape is None:
      dense_shape_dtype = None
    else:
      dense_shape_dtype = self._dense_shape.dtype
    return IndexedSlicesSpec(dense_shape, self.dtype, self._indices.dtype,
                             dense_shape_dtype, indices_shape)

  def consumers(self):
    """Returns the result of `self._consumers()`.

    NOTE(review): `_consumers` is not defined in this class; presumably it is
    inherited from `CompositeTensor`.
    """
    return self._consumers()


# Plain-data counterpart of `IndexedSlices`, returned by
# `IndexedSlicesSpec._from_components` when every component is a numpy array
# and TF2 behavior is disabled.
IndexedSlicesValue = collections.namedtuple(
    "IndexedSlicesValue", ["values", "indices", "dense_shape"])


@tf_export("IndexedSlicesSpec")
class IndexedSlicesSpec(type_spec.TypeSpec):
  """Type specification for a `tf.IndexedSlices`."""

  __slots__ = ["_shape", "_values_dtype", "_indices_dtype",
               "_dense_shape_dtype", "_indices_shape"]

  value_type = property(lambda self: IndexedSlices)

  def __init__(self, shape=None, dtype=dtypes.float32,
               indices_dtype=dtypes.int64, dense_shape_dtype=None,
               indices_shape=None):
    """Constructs a type specification for a `tf.IndexedSlices`.

    Args:
      shape: The dense shape of the `IndexedSlices`, or `None` to allow any
        dense shape.
      dtype: `tf.DType` of values in the `IndexedSlices`.
      indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`.  One
        of `tf.int32` or `tf.int64`.
      dense_shape_dtype: `tf.DType` of the `dense_shape` in the `IndexedSlices`.
        One of `tf.int32`, `tf.int64`, or `None` (if the `IndexedSlices` has
        no `dense_shape` tensor).
      indices_shape: The shape of the `indices` component, which indicates
        how many slices are in the `IndexedSlices`. Must be compatible with
        rank 1.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._values_dtype = dtypes.as_dtype(dtype)
    self._indices_dtype = dtypes.as_dtype(indices_dtype)
    if dense_shape_dtype is None:
      self._dense_shape_dtype = None
    else:
      self._dense_shape_dtype = dtypes.as_dtype(dense_shape_dtype)
    self._indices_shape = tensor_shape.as_shape(indices_shape).with_rank(1)

  def _serialize(self):
    return (self._shape, self._values_dtype, self._indices_dtype,
            self._dense_shape_dtype, self._indices_shape)

  @property
  def _component_specs(self):
    """Specs for (values, indices[, dense_shape]).

    The `dense_shape` spec is present only when `dense_shape_dtype` was given.
    """
    # Each slice has the dense shape minus its first dimension; the number of
    # slices is the length of `indices`.
    value_shape = self._indices_shape.concatenate(self._shape[1:])
    specs = [
        tensor_spec.TensorSpec(value_shape, self._values_dtype),
        tensor_spec.TensorSpec(self._indices_shape, self._indices_dtype)]
    if self._dense_shape_dtype is not None:
      specs.append(
          tensor_spec.TensorSpec([self._shape.ndims], self._dense_shape_dtype))
    return tuple(specs)

  def _to_components(self, value):
    if value.dense_shape is None:
      return (value.values, value.indices)
    else:
      return (value.values, value.indices, value.dense_shape)

  def _from_components(self, tensor_list):
    """Rebuilds a value from the components produced by `_to_components`.

    Returns an `IndexedSlicesValue` when all components are numpy arrays and
    TF2 behavior is disabled; otherwise an `IndexedSlices`.
    """
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      if len(tensor_list) == 2:
        return IndexedSlicesValue(tensor_list[0], tensor_list[1], None)
      else:
        return IndexedSlicesValue(*tensor_list)
    else:
      return IndexedSlices(*tensor_list)


@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  return internal_convert_to_tensor_or_indexed_slices(
      value=value, dtype=dtype, name=name, as_ref=False)


def internal_convert_to_tensor_or_indexed_slices(value,
                                                 dtype=None,
                                                 name=None,
                                                 as_ref=False):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.

  If `value` is an `IndexedSlices` or `SparseTensor` it is returned
  unmodified. Otherwise, it is converted to a `Tensor` using
  `convert_to_tensor()`.

  Args:
    value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
      by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name to use if a new `Tensor` is created.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.

  Raises:
    ValueError: If `dtype` does not match the element type of `value`.
  """
  if isinstance(value, ops.EagerTensor) and not context.executing_eagerly():
    # NOTE(review): an EagerTensor seen while not executing eagerly is routed
    # through convert_to_tensor instead of being returned as-is (presumably
    # so that it is captured into the current graph).
    return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
  # TODO(mdan): Name says tensor_or_indexed_slices. So do explicitly just that?
  elif isinstance(value, internal.NativeObject):
    # Compare against None rather than relying on truthiness: numpy dtypes
    # define `__len__` (their field count), so `np.dtype("float32")` is falsy
    # and would otherwise silently skip the compatibility check.
    if dtype is not None and not dtypes.as_dtype(dtype).is_compatible_with(
        value.dtype):
      raise ValueError(
          "Incompatible tensor conversion requested to `dtype` "
          f"{dtypes.as_dtype(dtype).name} for `value` ({value}) with dtype"
          f" {value.dtype.name}.")
    return value
  else:
    return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)


def internal_convert_n_to_tensor_or_indexed_slices(values,
                                                   dtype=None,
                                                   name=None,
                                                   as_ref=False):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified. `None` entries are passed through as `None`.

  Args:
    values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects
      that can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.
    as_ref: True if the caller wants the results as ref tensors.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If `values` is not iterable, or if no conversion function is
      registered for an element in `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  if not isinstance(values, collections_abc.Iterable):
    raise TypeError("Argument `values` must be iterable.")
  ret = []
  for i, value in enumerate(values):
    if value is None:
      ret.append(value)
    else:
      n = None if name is None else "%s_%d" % (name, i)
      ret.append(
          internal_convert_to_tensor_or_indexed_slices(
              value, dtype=dtype, name=n, as_ref=as_ref))
  return ret


def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None):
  """Converts `values` to a list of `Tensor` or `IndexedSlices` objects.

  Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
  unmodified. `None` entries are passed through as `None`.

  Args:
    values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
      can be consumed by `convert_to_tensor()`.
    dtype: (Optional.) The required `DType` of the returned `Tensor` or
      `IndexedSlices`.
    name: (Optional.) A name prefix to use when a new `Tensor` is created, in
      which case element `i` will be given the name `name + '_' + i`.

  Returns:
    A list of `Tensor`, `IndexedSlices`, `SparseTensor` and/or `None` objects.

  Raises:
    TypeError: If no conversion function is registered for an element in
      `values`.
    RuntimeError: If a registered conversion function returns an invalid
      value.
  """
  return internal_convert_n_to_tensor_or_indexed_slices(
      values=values, dtype=dtype, name=name, as_ref=False)


# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000


def _indexed_slices_to_tensor(value, dtype=None, name=None, as_ref=False):
  """Converts an IndexedSlices object `value` to a Tensor.

  NOTE(mrry): This function is potentially expensive.

  Args:
    value: An ops.IndexedSlices object.
    dtype: The dtype of the Tensor to be returned. Anything accepted by
      `dtypes.as_dtype` is allowed.
    name: Optional name to use for the returned Tensor.
    as_ref: True if a ref is requested. Ignored.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If `dtype` is incompatible with the dtype of the
      IndexedSlices, or if the IndexedSlices has no `dense_shape`.
  """
  _ = as_ref
  if dtype is not None:
    # Normalize so that non-DType inputs (e.g. numpy dtypes, which are also
    # falsy) are validated consistently with
    # internal_convert_to_tensor_or_indexed_slices.
    dtype = dtypes.as_dtype(dtype)
    if not dtype.is_compatible_with(value.dtype):
      raise ValueError(
          f"Incompatible tensor conversion requested to `dtype` {dtype.name} "
          f"for IndexedSlices ({value}) with dtype {value.dtype.name}")
  if value.dense_shape is None:
    raise ValueError(
        "Tensor conversion requested for IndexedSlices for argument `value` "
        f"without dense_shape: {value!s}")
  # TODO(mrry): Consider adding static shape information to
  # IndexedSlices, to avoid using numpy here.
  if not context.executing_eagerly():
    dense_shape_value = tensor_util.constant_value(value.dense_shape)
    if dense_shape_value is not None:
      # Accumulate in int64 so an int32 dense_shape cannot overflow on
      # platforms whose default integer is 32-bit, which would hide the
      # warning below.
      num_elements = np.prod(dense_shape_value, dtype=np.int64)
      if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
        warnings.warn(
            "Converting sparse IndexedSlices to a dense Tensor with %d "
            "elements. This may consume a large amount of memory." %
            num_elements)
    else:
      if value.dense_shape.op.type != "VariableShape":
        # VariableShape may hide static shapes behind a resource handle
        # producing a warning that isn't that useful to users.
        warnings.warn(
            "Converting sparse IndexedSlices(%s) to a dense Tensor of unknown "
            "shape. This may consume a large amount of memory." % value)
  # Scatter-add each slice into row `indices[i]` of a tensor with
  # `dense_shape[0]` rows; duplicate indices are summed.
  return math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)


tensor_conversion_registry.register_tensor_conversion_function(
    IndexedSlices, _indexed_slices_to_tensor)