xref: /aosp_15_r20/external/pytorch/torch/sparse/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# The Tensor classes are added to this module by python_tensor.cpp
3# A workaround to support both TorchScript and MyPy:
4from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union
5
6import torch
7from torch import Tensor
8from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
9
10# Semi structured sparsity support
11from .semi_structured import (
12    SparseSemiStructuredTensor,
13    SparseSemiStructuredTensorCUSPARSELT,
14    SparseSemiStructuredTensorCUTLASS,
15    to_sparse_semi_structured,
16)
17
18
19if TYPE_CHECKING:
20    from torch.types import _dtype as DType
21
22    DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
23else:
24    # The JIT doesn't understand Union, nor torch.dtype here
25    DType = int
26    DimOrDims = Optional[Tuple[int]]
27
28
29__all__ = [
30    "addmm",
31    "check_sparse_tensor_invariants",
32    "mm",
33    "sum",
34    "softmax",
35    "solve",
36    "log_softmax",
37    "SparseSemiStructuredTensor",
38    "SparseSemiStructuredTensorCUTLASS",
39    "SparseSemiStructuredTensorCUSPARSELT",
40    "to_sparse_semi_structured",
41    "as_sparse_gradcheck",
42]
43
44addmm = _add_docstr(
45    _sparse._sparse_addmm,
46    r"""
47sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
48
49This function does exact same thing as :func:`torch.addmm` in the forward,
50except that it supports backward for sparse COO matrix :attr:`mat1`.
51When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
52When inputs are COO tensors, this function also supports backward for both inputs.
53
54Supports both CSR and COO storage formats.
55
56.. note::
57    This function doesn't support computing derivaties with respect to CSR matrices.
58
59Args:
60    mat (Tensor): a dense matrix to be added
61    mat1 (Tensor): a sparse matrix to be multiplied
62    mat2 (Tensor): a dense matrix to be multiplied
63    beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
64    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
65""",
66)
67
68
69mm = _add_docstr(
70    _sparse._sparse_mm,
71    r"""
72    Performs a matrix multiplication of the sparse matrix :attr:`mat1`
73    and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
74    :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
75    :math:`(n \times p)` tensor.
76    When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
77    When inputs are COO tensors, this function also supports backward for both inputs.
78
79    Supports both CSR and COO storage formats.
80
81.. note::
82    This function doesn't support computing derivaties with respect to CSR matrices.
83
84    This function also additionally accepts an optional :attr:`reduce` argument that allows
85    specification of an optional reduction operation, mathematically performs the following operation:
86
87.. math::
88
89    z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}
90
91where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
92CSR storage format on CPU device.
93
94Args:
95    mat1 (Tensor): the first sparse matrix to be multiplied
96    mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
97    reduce (str, optional): the reduction operation to apply for non-unique indices
98        (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
99
100Shape:
101    The format of the output tensor of this function follows:
102    - sparse x sparse -> sparse
103    - sparse x dense -> dense
104
105Example::
106
107    >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
108    >>> a
109    tensor(indices=tensor([[0, 0, 1],
110                           [0, 2, 1]]),
111           values=tensor([1., 2., 3.]),
112           size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
113    >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
114    >>> b
115    tensor([[0., 1.],
116            [2., 0.],
117            [0., 0.]], requires_grad=True)
118    >>> y = torch.sparse.mm(a, b)
119    >>> y
120    tensor([[0., 1.],
121            [6., 0.]], grad_fn=<SparseAddmmBackward0>)
122    >>> y.sum().backward()
123    >>> a.grad
124    tensor(indices=tensor([[0, 0, 1],
125                           [0, 2, 1]]),
126           values=tensor([1., 0., 2.]),
127           size=(2, 3), nnz=3, layout=torch.sparse_coo)
128    >>> c = a.detach().to_sparse_csr()
129    >>> c
130    tensor(crow_indices=tensor([0, 2, 3]),
131           col_indices=tensor([0, 2, 1]),
132           values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
133           layout=torch.sparse_csr)
134    >>> y1 = torch.sparse.mm(c, b, 'sum')
135    >>> y1
136    tensor([[0., 1.],
137            [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
138    >>> y2 = torch.sparse.mm(c, b, 'max')
139    >>> y2
140    tensor([[0., 1.],
141            [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
142""",
143)
144
145
146sampled_addmm = _add_docstr(
147    _sparse.sparse_sampled_addmm,
148    r"""
149sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor
150
151Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
152specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.
153
154Mathematically this performs the following operation:
155
156.. math::
157
158    \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}
159
160where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, :attr:`alpha`
161and :attr:`beta` are the scaling factors.
162:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.
163
164.. note::
165    :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.
166
167Args:
168    input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
169        the sampled matrix multiplication
170    mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
171    mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied
172
173Keyword args:
174    beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
175    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
176    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
177
178Examples::
179
180    >>> input = torch.eye(3, device='cuda').to_sparse_csr()
181    >>> mat1 = torch.randn(3, 5, device='cuda')
182    >>> mat2 = torch.randn(5, 3, device='cuda')
183    >>> torch.sparse.sampled_addmm(input, mat1, mat2)
184    tensor(crow_indices=tensor([0, 1, 2, 3]),
185        col_indices=tensor([0, 1, 2]),
186        values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
187        size=(3, 3), nnz=3, layout=torch.sparse_csr)
188    >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
189    tensor([[ 0.2847,  0.0000,  0.0000],
190        [ 0.0000, -0.7805,  0.0000],
191        [ 0.0000,  0.0000, -0.1900]], device='cuda:0')
192    >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
193    tensor(crow_indices=tensor([0, 1, 2, 3]),
194        col_indices=tensor([0, 1, 2]),
195        values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
196        size=(3, 3), nnz=3, layout=torch.sparse_csr)
197""",
198)
199
200
201def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor:
202    r"""Return the sum of each row of the given sparse tensor.
203
204    Returns the sum of each row of the sparse tensor :attr:`input` in the given
205    dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
206    reduce over all of them. When sum over all ``sparse_dim``, this method
207    returns a dense tensor instead of a sparse tensor.
208
209    All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting an output
210    tensor having :attr:`dim` fewer dimensions than :attr:`input`.
211
212    During backward, only gradients at ``nnz`` locations of :attr:`input`
213    will propagate back. Note that the gradients of :attr:`input` is coalesced.
214
215    Args:
216        input (Tensor): the input sparse tensor
217        dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
218            over all dims.
219        dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
220            Default: dtype of :attr:`input`.
221
222    Example::
223
224        >>> nnz = 3
225        >>> dims = [5, 5, 2, 3]
226        >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
227                           torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
228        >>> V = torch.randn(nnz, dims[2], dims[3])
229        >>> size = torch.Size(dims)
230        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
231        >>> S = torch.sparse_coo_tensor(I, V, size)
232        >>> S
233        tensor(indices=tensor([[2, 0, 3],
234                               [2, 4, 1]]),
235               values=tensor([[[-0.6438, -1.6467,  1.4004],
236                               [ 0.3411,  0.0918, -0.2312]],
237
238                              [[ 0.5348,  0.0634, -2.0494],
239                               [-0.7125, -1.0646,  2.1844]],
240
241                              [[ 0.1276,  0.1874, -0.6334],
242                               [-1.9682, -0.5340,  0.7483]]]),
243               size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)
244
245        # when sum over only part of sparse_dims, return a sparse tensor
246        >>> torch.sparse.sum(S, [1, 3])
247        tensor(indices=tensor([[0, 2, 3]]),
248               values=tensor([[-1.4512,  0.4073],
249                              [-0.8901,  0.2017],
250                              [-0.3183, -1.7539]]),
251               size=(5, 2), nnz=3, layout=torch.sparse_coo)
252
253        # when sum over all sparse dim, return a dense tensor
254        # with summed dims squeezed
255        >>> torch.sparse.sum(S, [0, 1, 3])
256        tensor([-2.6596, -1.1450])
257    """
258    if dtype is None:
259        if dim is not None:
260            return torch._sparse_sum(input, dim)
261        else:
262            return torch._sparse_sum(input)
263    else:
264        if dim is not None:
265            return torch._sparse_sum(input, dim, dtype=dtype)
266        else:
267            return torch._sparse_sum(input, dtype=dtype)
268
269
270softmax = _add_docstr(
271    _sparse._sparse_softmax,
272    r"""
273sparse.softmax(input, dim, *, dtype=None) -> Tensor
274
275Applies a softmax function.
276
277Softmax is defined as:
278
279:math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
280
281where :math:`i, j` run over sparse tensor indices and unspecified
282entries are ignores. This is equivalent to defining unspecified
283entries as negative infinity so that :math:`exp(x_k) = 0` when the
284entry with index :math:`k` has not specified.
285
286It is applied to all slices along `dim`, and will re-scale them so
287that the elements lie in the range `[0, 1]` and sum to 1.
288
289Args:
290    input (Tensor): input
291    dim (int): A dimension along which softmax will be computed.
292    dtype (:class:`torch.dtype`, optional): the desired data type
293        of returned tensor.  If specified, the input tensor is
294        casted to :attr:`dtype` before the operation is
295        performed. This is useful for preventing data type
296        overflows. Default: None
297""",
298)
299
300
301spsolve = _add_docstr(
302    _sparse._spsolve,
303    r"""
304sparse.spsolve(input, other, *, left=True) -> Tensor
305
306Computes the solution of a square system of linear equations with
307a unique solution. Its purpose is similar to :func:`torch.linalg.solve`,
308except that the system is defined by a sparse CSR matrix with layout
309`sparse_csr`.
310
311Args:
312    input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the
313        coefficients of the linear system.
314    other (Tensor): a dense matrix of shape `(n, )` representing the right-hand
315        side of the linear system.
316    left (bool, optional): whether to solve the system for `input @ out = other`
317        (default) or `out @ input = other`. Only `left=True` is supported.
318""",
319)
320
321log_softmax = _add_docstr(
322    _sparse._sparse_log_softmax,
323    r"""
324sparse.log_softmax(input, dim, *, dtype=None) -> Tensor
325
326Applies a softmax function followed by logarithm.
327
328See :class:`~torch.sparse.softmax` for more details.
329
330Args:
331    input (Tensor): input
332    dim (int): A dimension along which softmax will be computed.
333    dtype (:class:`torch.dtype`, optional): the desired data type
334        of returned tensor.  If specified, the input tensor is
335        casted to :attr:`dtype` before the operation is
336        performed. This is useful for preventing data type
337        overflows. Default: None
338""",
339)
340
341
342spdiags = _add_docstr(
343    _sparse._spdiags,
344    r"""
345sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor
346
347Creates a sparse 2D tensor by placing the values from rows of
348:attr:`diagonals` along specified diagonals of the output
349
350The :attr:`offsets` tensor controls which diagonals are set.
351
352- If :attr:`offsets[i]` = 0, it is the main diagonal
353- If :attr:`offsets[i]` < 0, it is below the main diagonal
354- If :attr:`offsets[i]` > 0, it is above the main diagonal
355
356The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
357and an offset may not be repeated.
358
359Args:
360    diagonals (Tensor): Matrix storing diagonals row-wise
361    offsets (Tensor): The diagonals to be set, stored as a vector
362    shape (2-tuple of ints): The desired shape of the result
363Keyword args:
364    layout (:class:`torch.layout`, optional): The desired layout of the
365        returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
366        are supported. Default: ``torch.sparse_coo``
367
368Examples:
369
370Set the main and first two lower diagonals of a matrix::
371
372    >>> diags = torch.arange(9).reshape(3, 3)
373    >>> diags
374    tensor([[0, 1, 2],
375            [3, 4, 5],
376            [6, 7, 8]])
377    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
378    >>> s
379    tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
380                           [0, 1, 2, 0, 1, 0]]),
381           values=tensor([0, 1, 2, 3, 4, 6]),
382           size=(3, 3), nnz=6, layout=torch.sparse_coo)
383    >>> s.to_dense()
384    tensor([[0, 0, 0],
385            [3, 1, 0],
386            [6, 4, 2]])
387
388
389Change the output layout::
390
391    >>> diags = torch.arange(9).reshape(3, 3)
392    >>> diags
393    tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8])
394    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
395    >>> s
396    tensor(crow_indices=tensor([0, 1, 3, 6]),
397           col_indices=tensor([0, 0, 1, 0, 1, 2]),
398           values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
399           layout=torch.sparse_csr)
400    >>> s.to_dense()
401    tensor([[0, 0, 0],
402            [3, 1, 0],
403            [6, 4, 2]])
404
405Set partial diagonals of a large output::
406
407    >>> diags = torch.tensor([[1, 2], [3, 4]])
408    >>> offsets = torch.tensor([0, -1])
409    >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
410    tensor([[1, 0, 0, 0, 0],
411            [3, 2, 0, 0, 0],
412            [0, 4, 0, 0, 0],
413            [0, 0, 0, 0, 0],
414            [0, 0, 0, 0, 0]])
415
416.. note::
417
418    When setting the values along a given diagonal the index into the diagonal
419    and the index into the row of :attr:`diagonals` is taken as the
420    column index in the output. This has the effect that when setting a diagonal
421    with a positive offset `k` the first value along that diagonal will be
422    the value in position `k` of the row of :attr:`diagonals`
423
424Specifying a positive offset::
425
426    >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
427    >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
428    tensor([[1, 2, 3, 0, 0],
429            [0, 2, 3, 0, 0],
430            [0, 0, 3, 0, 0],
431            [0, 0, 0, 0, 0],
432            [0, 0, 0, 0, 0]])
433""",
434)
435
436
437class check_sparse_tensor_invariants:
438    """A tool to control checking sparse tensor invariants.
439
440    The following options exists to manage sparsr tensor invariants
441    checking in sparse tensor construction:
442
443    1. Using a context manager:
444
445       .. code:: python
446
447           with torch.sparse.check_sparse_tensor_invariants():
448               run_my_model()
449
450    2. Using a procedural approach:
451
452       .. code:: python
453
454           prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
455           torch.sparse.check_sparse_tensor_invariants.enable()
456
457           run_my_model()
458
459           if not prev_checks_enabled:
460               torch.sparse.check_sparse_tensor_invariants.disable()
461
462    3. Using function decoration:
463
464       .. code:: python
465
466           @torch.sparse.check_sparse_tensor_invariants()
467           def run_my_model():
468               ...
469
470           run_my_model()
471
472    4. Using ``check_invariants`` keyword argument in sparse tensor constructor call.
473       For example:
474
475       >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
476       Traceback (most recent call last):
477         File "<stdin>", line 1, in <module>
478       RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
479    """
480
481    @staticmethod
482    def is_enabled():
483        r"""Return True if the sparse tensor invariants checking is enabled.
484
485        .. note::
486
487            Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
488            :func:`torch.sparse.check_sparse_tensor_invariants.disable` to
489            manage the state of the sparse tensor invariants checks.
490        """
491        return torch._C._check_sparse_tensor_invariants()
492
493    @staticmethod
494    def enable():
495        r"""Enable sparse tensor invariants checking in sparse tensor constructors.
496
497        .. note::
498
499            By default, the sparse tensor invariants checks are disabled. Use
500            :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
501            retrieve the current state of sparse tensor invariants checking.
502
503        .. note::
504
505            The sparse tensor invariants check flag is effective to all sparse
506            tensor constructors, both in Python and ATen.
507
508        The flag can be locally overridden by the ``check_invariants``
509        optional argument of the sparse tensor constructor functions.
510        """
511        torch._C._set_check_sparse_tensor_invariants(True)
512
513    @staticmethod
514    def disable():
515        r"""Disable sparse tensor invariants checking in sparse tensor constructors.
516
517        See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
518        """
519        torch._C._set_check_sparse_tensor_invariants(False)
520
521    # context manager support
522    def __init__(self, enable=True):
523        self.state = enable
524        self.saved_state: Optional[bool] = None
525
526    def __enter__(self):
527        if self.saved_state is not None:
528            raise RuntimeError(
529                "This context manager instance is already activated."
530                " Use a different context manager instance for context nesting."
531            )
532        self.saved_state = self.is_enabled()
533        torch._C._set_check_sparse_tensor_invariants(self.state)
534
535    def __exit__(self, type, value, traceback):
536        assert self.saved_state is not None
537        torch._C._set_check_sparse_tensor_invariants(self.saved_state)
538        self.saved_state = None
539
540    # decorator support
541    def __call__(self, mth):
542        def test_mth(*args, **kwargs):
543            with type(self)(self.state):
544                return mth(*args, **kwargs)
545
546        return test_mth
547
548
549def as_sparse_gradcheck(gradcheck):
550    """Decorate function, to extend gradcheck for sparse tensors.
551
552    Decorator for torch.autograd.gradcheck or its functools.partial
553    variants that extends the gradcheck function with support to input
554    functions that operate on or/and return sparse tensors.
555
556    The specified gradcheck function itself is guaranteed to operate
557    on strided tensors only.
558
559    For example:
560
561    >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
562    >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
563    >>> gradcheck(lambda x: x.to_sparse_csr(), x)
564    True
565    """
566
567    def gradcheck_with_sparse_support(func, inputs, **kwargs):
568        """
569        Create gradcheck with support for sparse tensors.
570
571        Same as :func:`torch.autograd.gradcheck` but with sparse tensors inputs and outputs support.
572        """
573        masked = kwargs.pop("masked", False)
574        sparse_layouts = {
575            torch.sparse_coo,
576            torch.sparse_csr,
577            torch.sparse_csc,
578            torch.sparse_bsr,
579            torch.sparse_bsc,
580        }
581        sparse_compressed_layouts = {
582            torch.sparse_csr,
583            torch.sparse_csc,
584            torch.sparse_bsr,
585            torch.sparse_bsc,
586        }
587        sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
588        STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__"
589
590        def convert_to_strided_representation(args):
591            """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
592            if not isinstance(args, (list, tuple)):
593                args = (args,)
594            new_args: List[Any] = []
595            for obj in args:
596                if (
597                    isinstance(obj, torch.Tensor)
598                    and obj.requires_grad
599                    and obj.layout in sparse_layouts
600                ):
601                    d = dict(layout=obj.layout, shape=obj.shape)
602                    if not masked:
603                        # Materialize unspecified elements with zero values
604                        batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
605                        blocksize = (
606                            obj.values().shape[batch_dim + 1 : batch_dim + 3]
607                            if obj.layout in sparse_block_layouts
608                            else None
609                        )
610                        full_mask = torch.ones(
611                            obj.shape, device=obj.device, dtype=torch.bool
612                        ).to_sparse(
613                            layout=obj.layout,
614                            blocksize=blocksize,
615                            dense_dim=obj.dense_dim(),
616                        )
617                        obj = obj.to_dense().sparse_mask(full_mask)
618                    if obj.layout is torch.sparse_coo:
619                        d.update(
620                            indices=obj._indices(), is_coalesced=obj.is_coalesced()
621                        )
622                        values = obj._values()
623                    elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
624                        d.update(
625                            compressed_indices=obj.crow_indices(),
626                            plain_indices=obj.col_indices(),
627                        )
628                        values = obj.values()
629                    else:
630                        d.update(
631                            compressed_indices=obj.ccol_indices(),
632                            plain_indices=obj.row_indices(),
633                        )
634                        values = obj.values()
635                    new_args.extend(
636                        (STRIDED_REPRESENTATION, d, values.requires_grad_(True))
637                    )
638                else:
639                    new_args.append(obj)
640            return tuple(new_args)
641
642        def restore_from_strided_representation(args):
643            """Restore non-strided differentiable tensosr from their strided representations."""
644            new_args = []
645            args = list(args)
646            while args:
647                a = args.pop(0)
648                if a == STRIDED_REPRESENTATION:
649                    d, values = args.pop(0), args.pop(0)
650                    if d["layout"] is torch.sparse_coo:
651                        a = torch.sparse_coo_tensor(
652                            d["indices"],
653                            values,
654                            size=d["shape"],
655                            is_coalesced=d["is_coalesced"],
656                        )
657                    elif d["layout"] in sparse_compressed_layouts:
658                        a = torch.sparse_compressed_tensor(
659                            d["compressed_indices"],
660                            d["plain_indices"],
661                            values,
662                            size=d["shape"],
663                            layout=d["layout"],
664                        )
665                    else:
666                        raise NotImplementedError(
667                            f'conversion of {d["layout"]} strided representation to tensor'
668                        )
669                new_args.append(a)
670            return tuple(new_args)
671
672        def func_wrapper(*args, **kwargs):
673            restored_args = restore_from_strided_representation(args)
674
675            # convert differentiable output sparse tensors to strided
676            # tensors:
677            outputs = func(*restored_args, **kwargs)
678
679            strided_outputs = (
680                tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
681            )
682            strided_outputs = tuple(
683                (
684                    o.to_dense(masked_grad=masked)
685                    if isinstance(o, torch.Tensor)
686                    and o.requires_grad
687                    and o.layout in sparse_layouts
688                    else o
689                )
690                for o in strided_outputs
691            )
692
693            return (
694                strided_outputs
695                if isinstance(outputs, (list, tuple))
696                else strided_outputs[0]
697            )
698
699        args = (func_wrapper, convert_to_strided_representation(inputs))
700
701        return gradcheck(*args, **kwargs)
702
703    return gradcheck_with_sparse_support
704