# mypy: allow-untyped-defs
# The Tensor classes are added to this module by python_tensor.cpp
# A workaround to support both TorchScript and MyPy:
from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch import Tensor
from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]

# Semi-structured sparsity support
from .semi_structured import (
    SparseSemiStructuredTensor,
    SparseSemiStructuredTensorCUSPARSELT,
    SparseSemiStructuredTensorCUTLASS,
    to_sparse_semi_structured,
)


if TYPE_CHECKING:
    from torch.types import _dtype as DType

    DimOrDims = Optional[Union[int, Tuple[int, ...], List[int]]]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[Tuple[int]]


__all__ = [
    "addmm",
    "check_sparse_tensor_invariants",
    "mm",
    "sum",
    "softmax",
    "solve",
    "log_softmax",
    "SparseSemiStructuredTensor",
    "SparseSemiStructuredTensorCUTLASS",
    "SparseSemiStructuredTensorCUSPARSELT",
    "to_sparse_semi_structured",
    "as_sparse_gradcheck",
]

addmm = _add_docstr(
    _sparse._sparse_addmm,
    r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor

This function does the exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for sparse COO matrix :attr:`mat1`.
When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
When inputs are COO tensors, this function also supports backward for both inputs.

Supports both CSR and COO storage formats.

.. note::
    This function doesn't support computing derivatives with respect to CSR matrices.

Args:
    mat (Tensor): a dense matrix to be added
    mat1 (Tensor): a sparse matrix to be multiplied
    mat2 (Tensor): a dense matrix to be multiplied
    beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
""",
)


mm = _add_docstr(
    _sparse._sparse_mm,
    r"""
    Performs a matrix multiplication of the sparse matrix :attr:`mat1`
    and the (sparse or strided) matrix :attr:`mat2`. Similar to :func:`torch.mm`, if :attr:`mat1` is a
    :math:`(n \times m)` tensor, :attr:`mat2` is a :math:`(m \times p)` tensor, out will be a
    :math:`(n \times p)` tensor.
    When :attr:`mat1` is a COO tensor it must have `sparse_dim = 2`.
    When inputs are COO tensors, this function also supports backward for both inputs.

    Supports both CSR and COO storage formats.

.. note::
    This function doesn't support computing derivatives with respect to CSR matrices.

    This function additionally accepts an optional :attr:`reduce` argument that allows
    specification of an optional reduction operation, which mathematically performs the following operation:

.. math::

    z_{ij} = \bigoplus_{k = 0}^{K - 1} x_{ik} y_{kj}

where :math:`\bigoplus` defines the reduce operator. :attr:`reduce` is implemented only for
CSR storage format on CPU device.

Args:
    mat1 (Tensor): the first sparse matrix to be multiplied
    mat2 (Tensor): the second matrix to be multiplied, which could be sparse or dense
    reduce (str, optional): the reduction operation to apply for non-unique indices
        (:obj:`"sum"`, :obj:`"mean"`, :obj:`"amax"`, :obj:`"amin"`). Default :obj:`"sum"`.
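

# A minimal, illustrative sketch (hypothetical helper, not part of the public
# API) of how ``torch.sparse.addmm`` and ``torch.sparse.mm`` compose with
# autograd for a COO input; the name and values are only for illustration.
def _mm_addmm_example():
    a = torch.tensor([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]]).to_sparse_coo()
    a.requires_grad_(True)
    b = torch.randn(3, 2, requires_grad=True)
    m = torch.randn(2, 2)
    y = torch.sparse.mm(a, b)  # sparse x dense -> dense
    z = torch.sparse.addmm(m, a, b, beta=0.5, alpha=2.0)  # 0.5*m + 2*(a @ b)
    (y.sum() + z.sum()).backward()  # gradients reach both ``a`` and ``b``
    return a.grad, b.grad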


Shape:
    The format of the output tensor of this function follows:
    - sparse x sparse -> sparse
    - sparse x dense -> dense

Example::

    >>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_()
    >>> a
    tensor(indices=tensor([[0, 0, 1],
                           [0, 2, 1]]),
           values=tensor([1., 2., 3.]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True)
    >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True)
    >>> b
    tensor([[0., 1.],
            [2., 0.],
            [0., 0.]], requires_grad=True)
    >>> y = torch.sparse.mm(a, b)
    >>> y
    tensor([[0., 1.],
            [6., 0.]], grad_fn=<SparseAddmmBackward0>)
    >>> y.sum().backward()
    >>> a.grad
    tensor(indices=tensor([[0, 0, 1],
                           [0, 2, 1]]),
           values=tensor([1., 0., 2.]),
           size=(2, 3), nnz=3, layout=torch.sparse_coo)
    >>> c = a.detach().to_sparse_csr()
    >>> c
    tensor(crow_indices=tensor([0, 2, 3]),
           col_indices=tensor([0, 2, 1]),
           values=tensor([1., 2., 3.]), size=(2, 3), nnz=3,
           layout=torch.sparse_csr)
    >>> y1 = torch.sparse.mm(c, b, 'sum')
    >>> y1
    tensor([[0., 1.],
            [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
    >>> y2 = torch.sparse.mm(c, b, 'amax')
    >>> y2
    tensor([[0., 1.],
            [6., 0.]], grad_fn=<SparseMmReduceImplBackward0>)
""",
)


sampled_addmm = _add_docstr(
    _sparse.sparse_sampled_addmm,
    r"""
sparse.sampled_addmm(input, mat1, mat2, *, beta=1., alpha=1., out=None) -> Tensor

Performs a matrix multiplication of the dense matrices :attr:`mat1` and :attr:`mat2` at the locations
specified by the sparsity pattern of :attr:`input`. The matrix :attr:`input` is added to the final result.

Mathematically this performs the following operation:

.. math::

    \text{out} = \alpha\ (\text{mat1} \mathbin{@} \text{mat2})*\text{spy}(\text{input}) + \beta\ \text{input}

where :math:`\text{spy}(\text{input})` is the sparsity pattern matrix of :attr:`input`, and :attr:`alpha`
and :attr:`beta` are the scaling factors.
:math:`\text{spy}(\text{input})` has value 1 at the positions where :attr:`input` has non-zero values, and 0 elsewhere.

.. note::
    :attr:`input` must be a sparse CSR tensor. :attr:`mat1` and :attr:`mat2` must be dense tensors.

Args:
    input (Tensor): a sparse CSR matrix of shape `(m, n)` to be added and used to compute
        the sampled matrix multiplication
    mat1 (Tensor): a dense matrix of shape `(m, k)` to be multiplied
    mat2 (Tensor): a dense matrix of shape `(k, n)` to be multiplied

Keyword args:
    beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`)
    alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.

Examples::

    >>> input = torch.eye(3, device='cuda').to_sparse_csr()
    >>> mat1 = torch.randn(3, 5, device='cuda')
    >>> mat2 = torch.randn(5, 3, device='cuda')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.2847, -0.7805, -0.1900]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
    >>> torch.sparse.sampled_addmm(input, mat1, mat2).to_dense()
    tensor([[ 0.2847,  0.0000,  0.0000],
            [ 0.0000, -0.7805,  0.0000],
            [ 0.0000,  0.0000, -0.1900]], device='cuda:0')
    >>> torch.sparse.sampled_addmm(input, mat1, mat2, beta=0.5, alpha=0.5)
    tensor(crow_indices=tensor([0, 1, 2, 3]),
           col_indices=tensor([0, 1, 2]),
           values=tensor([ 0.1423, -0.3903, -0.0950]), device='cuda:0',
           size=(3, 3), nnz=3, layout=torch.sparse_csr)
""",
)
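

# An illustrative cross-check (hypothetical helper, not part of the public
# API): ``sampled_addmm`` should agree with the dense formula
# ``alpha * (mat1 @ mat2) * spy(input) + beta * input`` restricted to the
# sparsity pattern. Assumes a CUDA device, matching the docstring examples.
def _sampled_addmm_example():
    inp = torch.eye(3, device="cuda").to_sparse_csr()
    mat1 = torch.randn(3, 5, device="cuda")
    mat2 = torch.randn(5, 3, device="cuda")
    out = torch.sparse.sampled_addmm(inp, mat1, mat2, beta=0.5, alpha=2.0)
    spy = torch.eye(3, device="cuda")  # sparsity pattern of ``inp``
    ref = 2.0 * (mat1 @ mat2) * spy + 0.5 * inp.to_dense()
    return torch.allclose(out.to_dense(), ref)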


def sum(input: Tensor, dim: DimOrDims = None, dtype: Optional[DType] = None) -> Tensor:
    r"""Return the sum of each row of the given sparse tensor.

    Returns the sum of each row of the sparse tensor :attr:`input` in the given
    dimensions :attr:`dim`. If :attr:`dim` is a list of dimensions,
    reduce over all of them. When summing over all of the ``sparse_dim`` dimensions,
    this method returns a dense tensor instead of a sparse tensor.

    All summed :attr:`dim` are squeezed (see :func:`torch.squeeze`), resulting in an output
    tensor having :attr:`dim` fewer dimensions than :attr:`input`.

    During backward, only gradients at ``nnz`` locations of :attr:`input`
    will propagate back. Note that the gradients of :attr:`input` are coalesced.

    Args:
        input (Tensor): the input sparse tensor
        dim (int or tuple of ints): a dimension or a list of dimensions to reduce. Default: reduce
            over all dims.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor.
            Default: dtype of :attr:`input`.

    Example::

        >>> nnz = 3
        >>> dims = [5, 5, 2, 3]
        >>> I = torch.cat([torch.randint(0, dims[0], size=(nnz,)),
        ...                torch.randint(0, dims[1], size=(nnz,))], 0).reshape(2, nnz)
        >>> V = torch.randn(nnz, dims[2], dims[3])
        >>> size = torch.Size(dims)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> S = torch.sparse_coo_tensor(I, V, size)
        >>> S
        tensor(indices=tensor([[2, 0, 3],
                               [2, 4, 1]]),
               values=tensor([[[-0.6438, -1.6467,  1.4004],
                               [ 0.3411,  0.0918, -0.2312]],

                              [[ 0.5348,  0.0634, -2.0494],
                               [-0.7125, -1.0646,  2.1844]],

                              [[ 0.1276,  0.1874, -0.6334],
                               [-1.9682, -0.5340,  0.7483]]]),
               size=(5, 5, 2, 3), nnz=3, layout=torch.sparse_coo)

        # when summing over only part of the sparse dims, return a sparse tensor
        >>> torch.sparse.sum(S, [1, 3])
        tensor(indices=tensor([[0, 2, 3]]),
               values=tensor([[-1.4512,  0.4073],
                              [-0.8901,  0.2017],
                              [-0.3183, -1.7539]]),
               size=(5, 2), nnz=3, layout=torch.sparse_coo)

        # when summing over all sparse dims, return a dense tensor
        # with summed dims squeezed
        >>> torch.sparse.sum(S, [0, 1, 3])
        tensor([-2.6596, -1.1450])
    """
    if dtype is None:
        if dim is not None:
            return torch._sparse_sum(input, dim)
        else:
            return torch._sparse_sum(input)
    else:
        if dim is not None:
            return torch._sparse_sum(input, dim, dtype=dtype)
        else:
            return torch._sparse_sum(input, dtype=dtype)
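

# A short sketch (hypothetical helper, not part of the public API) of the rule
# documented above: reducing only some sparse dims keeps the result sparse,
# while reducing all sparse dims produces a dense tensor.
def _sparse_sum_example():
    i = torch.tensor([[0, 1, 1], [2, 0, 2]])
    v = torch.tensor([3.0, 4.0, 5.0])
    s = torch.sparse_coo_tensor(i, v, (2, 3))
    partial = torch.sparse.sum(s, dim=1)  # sparse: one sparse dim remains
    full = torch.sparse.sum(s, dim=(0, 1))  # dense scalar: all sparse dims reduced
    return partial.is_sparse, full.is_sparse  # (True, False)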


softmax = _add_docstr(
    _sparse._sparse_softmax,
    r"""
sparse.softmax(input, dim, *, dtype=None) -> Tensor

Applies a softmax function.

Softmax is defined as:

:math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`

where :math:`i, j` run over sparse tensor indices and unspecified
entries are ignored. This is equivalent to defining unspecified
entries as negative infinity so that :math:`\exp(x_k) = 0` when the
entry with index :math:`k` is not specified.

It is applied to all slices along `dim`, and will re-scale them so
that the elements lie in the range `[0, 1]` and sum to 1.

Args:
    input (Tensor): input
    dim (int): A dimension along which softmax will be computed.
    dtype (:class:`torch.dtype`, optional): the desired data type
        of returned tensor. If specified, the input tensor is
        cast to :attr:`dtype` before the operation is
        performed. This is useful for preventing data type
        overflows. Default: None
""",
)


spsolve = _add_docstr(
    _sparse._spsolve,
    r"""
sparse.spsolve(input, other, *, left=True) -> Tensor

Computes the solution of a square system of linear equations with
a unique solution. Its purpose is similar to :func:`torch.linalg.solve`,
except that the system is defined by a sparse CSR matrix with layout
`sparse_csr`.

Args:
    input (Tensor): a sparse CSR matrix of shape `(n, n)` representing the
        coefficients of the linear system.
    other (Tensor): a dense vector of shape `(n,)` representing the right-hand
        side of the linear system.
    left (bool, optional): whether to solve the system for `input @ out = other`
        (default) or `out @ input = other`. Only `left=True` is supported.
""",
)


log_softmax = _add_docstr(
    _sparse._sparse_log_softmax,
    r"""
sparse.log_softmax(input, dim, *, dtype=None) -> Tensor

Applies a softmax function followed by a logarithm.

See :class:`~torch.sparse.softmax` for more details.

Args:
    input (Tensor): input
    dim (int): A dimension along which softmax will be computed.
    dtype (:class:`torch.dtype`, optional): the desired data type
        of returned tensor. If specified, the input tensor is
        cast to :attr:`dtype` before the operation is
        performed. This is useful for preventing data type
        overflows. Default: None
""",
)
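

# A sketch (hypothetical helper, not part of the public API) verifying the
# statement above: sparse softmax over specified entries matches a dense
# softmax in which unspecified entries are filled with ``-inf``, so they
# contribute ``exp(-inf) == 0`` to the normalization.
def _sparse_softmax_example():
    i = torch.tensor([[0, 0, 1], [0, 2, 2]])
    v = torch.tensor([1.0, 2.0, 3.0])
    s = torch.sparse_coo_tensor(i, v, (2, 3))
    sp = torch.sparse.softmax(s, dim=1).to_dense()
    # Dense reference: unspecified entries behave like -inf.
    dense = torch.full((2, 3), float("-inf"))
    dense[0, 0], dense[0, 2], dense[1, 2] = 1.0, 2.0, 3.0
    ref = torch.softmax(dense, dim=1)
    return torch.allclose(sp, ref)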


spdiags = _add_docstr(
    _sparse._spdiags,
    r"""
sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor

Creates a sparse 2D tensor by placing the values from rows of
:attr:`diagonals` along specified diagonals of the output.

The :attr:`offsets` tensor controls which diagonals are set.

- If :attr:`offsets[i]` = 0, it is the main diagonal
- If :attr:`offsets[i]` < 0, it is below the main diagonal
- If :attr:`offsets[i]` > 0, it is above the main diagonal

The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
and an offset may not be repeated.

Args:
    diagonals (Tensor): Matrix storing diagonals row-wise
    offsets (Tensor): The diagonals to be set, stored as a vector
    shape (2-tuple of ints): The desired shape of the result

Keyword args:
    layout (:class:`torch.layout`, optional): The desired layout of the
        returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
        are supported. Default: ``torch.sparse_coo``

Examples:

Set the main and first two lower diagonals of a matrix::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
    >>> s
    tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
                           [0, 1, 2, 0, 1, 0]]),
           values=tensor([0, 1, 2, 3, 4, 6]),
           size=(3, 3), nnz=6, layout=torch.sparse_coo)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])


Change the output layout::

    >>> diags = torch.arange(9).reshape(3, 3)
    >>> diags
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
    >>> s
    tensor(crow_indices=tensor([0, 1, 3, 6]),
           col_indices=tensor([0, 0, 1, 0, 1, 2]),
           values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
           layout=torch.sparse_csr)
    >>> s.to_dense()
    tensor([[0, 0, 0],
            [3, 1, 0],
            [6, 4, 2]])

Set partial diagonals of a large output::

    >>> diags = torch.tensor([[1, 2], [3, 4]])
    >>> offsets = torch.tensor([0, -1])
    >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
    tensor([[1, 0, 0, 0, 0],
            [3, 2, 0, 0, 0],
            [0, 4, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]])

.. note::

    When setting the values along a given diagonal, the index into the diagonal
    and the index into the row of :attr:`diagonals` are both taken to be the
    column index in the output. This has the effect that when setting a diagonal
    with a positive offset `k`, the first value along that diagonal will be
    the value in position `k` of the row of :attr:`diagonals`.

Specifying a positive offset::

    >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
    tensor([[1, 2, 3, 0, 0],
            [0, 2, 3, 0, 0],
            [0, 0, 3, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0]])
""",
)
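

# A check (hypothetical helper, not part of the public API) of the note above:
# with a positive offset ``k``, the diagonal reads from position ``k`` of its
# row in ``diagonals``; the values here are only for illustration.
def _spdiags_offset_example():
    row = torch.tensor([[9, 8, 7, 6, 5]])
    out = torch.sparse.spdiags(row, torch.tensor([2]), (5, 5)).to_dense()
    # The diagonal at offset 2 holds row[0, 2:], i.e. [7, 6, 5].
    assert torch.equal(torch.diagonal(out, offset=2), row[0, 2:])
    return out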


class check_sparse_tensor_invariants:
    """A tool to control checking sparse tensor invariants.

    The following options exist to manage sparse tensor invariants
    checking in sparse tensor construction:

    1. Using a context manager:

       .. code:: python

           with torch.sparse.check_sparse_tensor_invariants():
               run_my_model()

    2. Using a procedural approach:

       .. code:: python

           prev_checks_enabled = torch.sparse.check_sparse_tensor_invariants.is_enabled()
           torch.sparse.check_sparse_tensor_invariants.enable()

           run_my_model()

           if not prev_checks_enabled:
               torch.sparse.check_sparse_tensor_invariants.disable()

    3. Using function decoration:

       .. code:: python

           @torch.sparse.check_sparse_tensor_invariants()
           def run_my_model():
               ...

           run_my_model()

    4. Using the ``check_invariants`` keyword argument in a sparse tensor constructor call.
       For example:

       >>> torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2], check_invariants=True)
       Traceback (most recent call last):
         File "<stdin>", line 1, in <module>
       RuntimeError: `crow_indices[..., -1] == nnz` is not satisfied.
    """

    @staticmethod
    def is_enabled():
        r"""Return True if the sparse tensor invariants checking is enabled.

        .. note::

            Use :func:`torch.sparse.check_sparse_tensor_invariants.enable` or
            :func:`torch.sparse.check_sparse_tensor_invariants.disable` to
            manage the state of the sparse tensor invariants checks.
        """
        return torch._C._check_sparse_tensor_invariants()

    @staticmethod
    def enable():
        r"""Enable sparse tensor invariants checking in sparse tensor constructors.

        .. note::

            By default, the sparse tensor invariants checks are disabled. Use
            :func:`torch.sparse.check_sparse_tensor_invariants.is_enabled` to
            retrieve the current state of sparse tensor invariants checking.

        .. note::

            The sparse tensor invariants check flag applies to all sparse
            tensor constructors, both in Python and ATen.

            The flag can be locally overridden by the ``check_invariants``
            optional argument of the sparse tensor constructor functions.
        """
        torch._C._set_check_sparse_tensor_invariants(True)

    @staticmethod
    def disable():
        r"""Disable sparse tensor invariants checking in sparse tensor constructors.

        See :func:`torch.sparse.check_sparse_tensor_invariants.enable` for more information.
        """
        torch._C._set_check_sparse_tensor_invariants(False)

    # context manager support
    def __init__(self, enable=True):
        self.state = enable
        self.saved_state: Optional[bool] = None

    def __enter__(self):
        if self.saved_state is not None:
            raise RuntimeError(
                "This context manager instance is already activated."
                " Use a different context manager instance for context nesting."
            )
        self.saved_state = self.is_enabled()
        torch._C._set_check_sparse_tensor_invariants(self.state)

    def __exit__(self, type, value, traceback):
        assert self.saved_state is not None
        torch._C._set_check_sparse_tensor_invariants(self.saved_state)
        self.saved_state = None

    # decorator support
    def __call__(self, mth):
        def test_mth(*args, **kwargs):
            with type(self)(self.state):
                return mth(*args, **kwargs)

        return test_mth
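

# A sketch (hypothetical helper, not part of the public API) of catching an
# invalid CSR tensor at construction time with the context manager documented
# above; the bad ``crow_indices`` deliberately violate
# ``crow_indices[..., -1] == nnz``.
def _invariants_check_example():
    try:
        with torch.sparse.check_sparse_tensor_invariants():
            torch.sparse_csr_tensor([0, 1, 3], [0, 1], [1, 2])
    except RuntimeError as e:
        return str(e)  # reports the violated invariant
    return None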


def as_sparse_gradcheck(gradcheck):
    """Decorate a function to extend gradcheck for sparse tensors.

    Decorator for torch.autograd.gradcheck or its functools.partial
    variants that extends the gradcheck function with support for input
    functions that operate on and/or return sparse tensors.

    The specified gradcheck function itself is guaranteed to operate
    on strided tensors only.

    For example:

    >>> gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
    >>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse_coo().requires_grad_(True)
    >>> gradcheck(lambda x: x.to_sparse_csr(), x)
    True
    """

    def gradcheck_with_sparse_support(func, inputs, **kwargs):
        """
        Create gradcheck with support for sparse tensors.

        Same as :func:`torch.autograd.gradcheck` but with support for sparse tensor inputs and outputs.
        """
        masked = kwargs.pop("masked", False)
        sparse_layouts = {
            torch.sparse_coo,
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }
        sparse_compressed_layouts = {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        }
        sparse_block_layouts = {torch.sparse_bsr, torch.sparse_bsc}
        STRIDED_REPRESENTATION = "__STRIDED_REPRESENTATION__"

        def convert_to_strided_representation(args):
            """Convert differentiable non-strided tensors to a representation containing differentiable strided tensors."""
            if not isinstance(args, (list, tuple)):
                args = (args,)
            new_args: List[Any] = []
            for obj in args:
                if (
                    isinstance(obj, torch.Tensor)
                    and obj.requires_grad
                    and obj.layout in sparse_layouts
                ):
                    d = dict(layout=obj.layout, shape=obj.shape)
                    if not masked:
                        # Materialize unspecified elements with zero values
                        batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
                        blocksize = (
                            obj.values().shape[batch_dim + 1 : batch_dim + 3]
                            if obj.layout in sparse_block_layouts
                            else None
                        )
                        full_mask = torch.ones(
                            obj.shape, device=obj.device, dtype=torch.bool
                        ).to_sparse(
                            layout=obj.layout,
                            blocksize=blocksize,
                            dense_dim=obj.dense_dim(),
                        )
                        obj = obj.to_dense().sparse_mask(full_mask)
                    if obj.layout is torch.sparse_coo:
                        d.update(
                            indices=obj._indices(), is_coalesced=obj.is_coalesced()
                        )
                        values = obj._values()
                    elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
                        d.update(
                            compressed_indices=obj.crow_indices(),
                            plain_indices=obj.col_indices(),
                        )
                        values = obj.values()
                    else:
                        d.update(
                            compressed_indices=obj.ccol_indices(),
                            plain_indices=obj.row_indices(),
                        )
                        values = obj.values()
                    new_args.extend(
                        (STRIDED_REPRESENTATION, d, values.requires_grad_(True))
                    )
                else:
                    new_args.append(obj)
            return tuple(new_args)

        def restore_from_strided_representation(args):
            """Restore non-strided differentiable tensors from their strided representations."""
            new_args = []
            args = list(args)
            while args:
                a = args.pop(0)
                if a == STRIDED_REPRESENTATION:
                    d, values = args.pop(0), args.pop(0)
                    if d["layout"] is torch.sparse_coo:
                        a = torch.sparse_coo_tensor(
                            d["indices"],
                            values,
                            size=d["shape"],
                            is_coalesced=d["is_coalesced"],
                        )
                    elif d["layout"] in sparse_compressed_layouts:
                        a = torch.sparse_compressed_tensor(
                            d["compressed_indices"],
                            d["plain_indices"],
                            values,
                            size=d["shape"],
                            layout=d["layout"],
                        )
                    else:
                        raise NotImplementedError(
                            f'conversion of {d["layout"]} strided representation to tensor'
                        )
                new_args.append(a)
            return tuple(new_args)

        def func_wrapper(*args, **kwargs):
            restored_args = restore_from_strided_representation(args)

            # convert differentiable output sparse tensors to strided
            # tensors:
            outputs = func(*restored_args, **kwargs)

            strided_outputs = (
                tuple(outputs) if isinstance(outputs, (list, tuple)) else (outputs,)
            )
            strided_outputs = tuple(
                (
                    o.to_dense(masked_grad=masked)
                    if isinstance(o, torch.Tensor)
                    and o.requires_grad
                    and o.layout in sparse_layouts
                    else o
                )
                for o in strided_outputs
            )

            return (
                strided_outputs
                if isinstance(outputs, (list, tuple))
                else strided_outputs[0]
            )

        args = (func_wrapper, convert_to_strided_representation(inputs))

        return gradcheck(*args, **kwargs)

    return gradcheck_with_sparse_support
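

# An end-to-end usage sketch (hypothetical helper, not part of the public
# API) of ``as_sparse_gradcheck``: wrap the stock gradcheck, then check a
# function that consumes and returns sparse tensors; the wrapped gradcheck
# itself only ever sees strided tensors.
def _as_sparse_gradcheck_example():
    gradcheck = torch.sparse.as_sparse_gradcheck(torch.autograd.gradcheck)
    x = (
        torch.tensor([[0.0, 1.0], [2.0, 3.0]], dtype=torch.float64)
        .to_sparse_coo()
        .requires_grad_(True)
    )
    return gradcheck(lambda t: t.to_sparse_csr(), x, masked=False)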