xref: /aosp_15_r20/external/pytorch/test/test_sparse_csr.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: sparse"]
2
3import torch
4import random
5import io
6import itertools
7import unittest
8import functools
9from contextlib import redirect_stderr
10from torch.testing import make_tensor, FileCheck
11from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
12from torch.testing._internal.common_utils import \
13    (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
14     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
15     suppress_warnings)
16from torch.testing._internal.common_device_type import \
17    (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
18     precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
19     largeTensorTest)
20from torch.testing._internal.common_methods_invocations import \
21    (op_db, sparse_csr_unary_ufuncs, ReductionOpInfo)
22from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CUDA
23from torch.testing._internal.common_dtype import (
24    floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
25    all_types_and_complex, floating_and_complex_types_and)
26from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
27from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
28from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
29import operator
30
31if TEST_SCIPY:
32    import scipy.sparse as sp
33
34if TEST_NUMPY:
35    import numpy as np
36# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
37# sharding on sandcastle. This line silences flake warnings
38load_tests = load_tests
39
40no_mkl_sparse = IS_WINDOWS or not TEST_MKL
41
42def _check_cusparse_triangular_solve_available():
43    version = _get_torch_cuda_version()
44    # cusparseSpSM was added in 11.3.1 but we don't have access to the patch version
45    min_supported_version = (11, 4)
46    return version >= min_supported_version
47
48def _check_cusparse_spgemm_available():
49    # cusparseSpGEMM was added in 11.0
50    return not TEST_WITH_ROCM
51
52def _check_cusparse_sddmm_available():
53    if TEST_WITH_ROCM:
54        return True
55    version = _get_torch_cuda_version()
56    # cusparseSDDMM was added in 11.2.1 but we don't have access to the patch version
57    min_supported_version = (11, 3)
58    return version >= min_supported_version
59
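# The three cuSPARSE availability helpers above gate CUDA-specific tests in
# this file. A hedged usage sketch (the test name and skip message here are
# illustrative, not the actual ones used below):
#
#   @skipCUDAIf(not _check_cusparse_sddmm_available(),
#               "cuSPARSE SDDMM is not available for this CUDA version")
#   def test_sampled_addmm(self, device, dtype):
#       ...
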
60_sparse_csr_ops = list(filter(lambda op: op.supports_sparse_csr, op_db))
61_sparse_compressed_ops = list(filter(lambda op: (op.supports_sparse_csr or op.supports_sparse_csc
62                                                 or op.supports_sparse_bsr or op.supports_sparse_bsc), op_db))
63binary_functions_with_dense_output = ['mm', 'mv', ]
64binary_ops_with_dense_output = list(filter(lambda op: op.name in binary_functions_with_dense_output, op_db))
65
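# Unary elementwise ops that are expected to support autograd with sparse CSR
# inputs; this list is used to gate the CSR unary-ufunc autograd tests defined
# later in this file.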
66UNARY_EWISE_CSR_ALLOW_AUTOGRAD = [
67    'abs',
68    'conj_physical',
69    'deg2rad',
70    'neg',
71    'positive',
72    'frac',
73    'nn.functional.relu',
74    'log1p',
75    'rad2deg'
76]
77
78# This should just be an import from test_linalg instead of code duplication,
79# but see https://github.com/pytorch/pytorch/pull/63511#discussion_r733989701
80def _test_addmm_addmv(
81    test_case,
82    f,
83    t,
84    m,
85    v,
86    *,
87    alpha=None,
88    beta=None,
89    transpose_out=False,
90    layout=torch.strided,
91    mode=None
92):
93    """
94    Unified test for checking `f(t, m, v, alpha=alpha, beta=beta)` computation,
95    where f is `torch.addmv` or `torch.addmm`.
96    `transpose_out` controls whether the out argument is in column-major order.
97    `layout` controls whether `m` is converted to the specified layout or not.
98    Custom behaviour is implemented only for torch.sparse_csr layout.
99    """
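    # The numpy reference computed below is, in effect,
    #   expected = beta * t + alpha * (m @ v)
    # with bfloat16 inputs promoted to float32 for the reference computation.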
100    dtype = t.dtype
101    numpy_dtype = dtype
102    if dtype in {torch.bfloat16}:
103        numpy_dtype = torch.float
104    if dtype.is_complex:
105        alpha = 0.9 + 0.3j if alpha is None else alpha
106        beta = 0.5 + 0.6j if beta is None else beta
107    else:
108        alpha = 1.2 if alpha is None else alpha
109        beta = 0.8 if beta is None else beta
110
111    def convert_layout(mat):
112        if layout == torch.sparse_csr:
113            return mat.to_sparse_csr()
114        elif layout == torch.sparse_csc:
115            return mat.to_sparse_csc()
116        else:
117            assert mat.layout == layout
118            return mat
119
120    if mode == "all_sparse":
121        res1 = f(*map(convert_layout, (t, m, v)), alpha=alpha, beta=beta)
122        test_case.assertEqual(res1.layout, layout)
123        res1 = res1.to_dense()
124    elif mode == "dense_result":
125        res1 = f(t, convert_layout(m), convert_layout(v), alpha=alpha, beta=beta)
126    else:
127        res1 = f(t, convert_layout(m), v, alpha=alpha, beta=beta)
128    res2 = torch.full_like(res1, float('nan'))
129    if transpose_out:
130        res2 = res2.t().clone(memory_format=torch.contiguous_format).t()
131    f(t, convert_layout(m), v, alpha=alpha, beta=beta, out=res2)
132    res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy())
133    if beta != 0:
134        res3 += (beta * t).to(numpy_dtype).cpu().numpy()
135    res3 = torch.from_numpy(res3).to(dtype)
136    test_case.assertEqual(res1, res2)
137    test_case.assertEqual(res1, res3)
138
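# A minimal usage sketch for the helper above (shapes and the call site are
# illustrative; the actual callers appear further down in this file):
#
#   t = make_tensor((3,), dtype=torch.float64, device='cpu')
#   m = make_tensor((3, 5), dtype=torch.float64, device='cpu')
#   v = make_tensor((5,), dtype=torch.float64, device='cpu')
#   _test_addmm_addmv(self, torch.addmv, t, m, v, layout=torch.sparse_csr)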
139
140class TestSparseCSRSampler(TestCase):
141
142    def test_make_crow_indices(self):
143        # Here we test the correctness of the crow_indices algorithm;
144        # testing it on CPU with the int32 dtype is
145        # sufficient.
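        # The assertions below check the compressed-row invariants:
        # crow_indices has n_rows + 1 entries, the per-row counts
        # crow_indices[1:] - crow_indices[:-1] sum to nnz, and every count
        # lies in [0, n_cols].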
146        device = torch.device('cpu')
147        index_dtype = torch.int32
148        for n_rows in range(1, 10):
149            for n_cols in range(1, 10):
150                for nnz in range(0, n_rows * n_cols + 1):
151                    crow_indices = self._make_crow_indices(
152                        n_rows, n_cols, nnz,
153                        device=device, dtype=index_dtype)
154                    self.assertEqual(len(crow_indices), n_rows + 1)
155                    counts = crow_indices[1:] - crow_indices[:-1]
156                    self.assertEqual(counts.sum(), nnz)
157                    self.assertGreaterEqual(counts.min(), 0)
158                    self.assertLessEqual(counts.max(), n_cols)
159
160
161def all_sparse_compressed_layouts(test_name='layout'):
162    return parametrize(test_name, [
163        subtest(torch.sparse_csr, name='SparseCSR'),
164        subtest(torch.sparse_csc, name='SparseCSC'),
165        subtest(torch.sparse_bsr, name='SparseBSR'),
166        subtest(torch.sparse_bsc, name='SparseBSC')])
167
168
169def sparse_compressed_nonblock_layouts(test_name='layout'):
170    return parametrize(test_name, [
171        subtest(torch.sparse_csr, name='SparseCSR'),
172        subtest(torch.sparse_csc, name='SparseCSC')])
173
174
175sparse_compressed_indices_methods = {
176    torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
177    torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
178    torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
179    torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
180}
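
# Illustrative sketch (not executed as part of the tests): for a small CSR
# tensor the accessor pairs above map as follows:
#
#   x = torch.tensor([[0., 5., 0.],
#                     [7., 0., 0.],
#                     [0., 0., 9.]]).to_sparse_csr()
#   x.crow_indices()  # tensor([0, 1, 2, 3]) -- compressed (row pointer) indices
#   x.col_indices()   # tensor([1, 0, 2])    -- plain (column) indices
#   x.values()        # tensor([5., 7., 9.])
#
# CSC/BSC compress along columns instead, hence ccol_indices/row_indices.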
181
182
183def batched_nonbatched(test_name='batched'):
184    return parametrize(test_name, [
185        subtest(True, name="Batched"),
186        subtest(False, name="NonBatched")
187    ])
188
189
190def hybrid_nonhybrid(test_name='hybrid'):
191    return parametrize(test_name, [
192        subtest(True, name="Hybrid"),
193        subtest(False, name="NonHybrid")
194    ])
195
196
197class TestSparseCompressed(TestCase):
198    """Testing sparse compressed (CSR, CSC, BSR, BSC) tensor generic features.
199    """
200
201    def genTensor(self, size, nnz, *, layout, device=None, dtype=torch.float, index_dtype=torch.int64):
202        if device is None:
203            device = self.device_type
204        return self.genSparseCompressedTensor(size, nnz, device=device, dtype=dtype, index_dtype=index_dtype, layout=layout)
205
206    @all_sparse_compressed_layouts()
207    @onlyCPU
208    def test_layout(self, layout):
209        self.assertIn(str(layout), {'torch.sparse_csr', 'torch.sparse_csc', 'torch.sparse_bsr', 'torch.sparse_bsc'})
210        self.assertEqual(type(layout), torch.layout)
211
212    @parametrize('shape_and_device_inference', [subtest(False, name='_'), subtest(True, name='shape_and_device_inference')])
213    @parametrize('use_factory_function', [subtest(False, name='_'), subtest(True, name='factory')])
214    @parametrize('input_kind', [subtest('tensor', name='from_tensor'), subtest('list', name='from_list')])
215    @all_sparse_compressed_layouts()
216    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
217    def test_sparse_compressed_constructor(self, layout, device, dtype,
218                                           use_factory_function, shape_and_device_inference, input_kind):
219        if input_kind == 'list' and shape_and_device_inference:
220            if torch.device(device).type == 'cuda':
221                # list inputs to the factory/constructor function without
222                # specifying a device will result in a sparse compressed tensor
223                # on CPU. So, skip testing against the cuda device as unused.
224                self.skipTest("nothing to test")
225            if dtype not in {torch.float32, torch.complex64, torch.int64, torch.bool}:
226                self.skipTest("dtype not supported with list values")
227
228        expected_devices = [torch.device(device)]
229        if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2 and not shape_and_device_inference:
230            expected_devices.append(torch.device('cuda:1'))
231
232        factory_function = {
233            torch.sparse_csr: torch.sparse_csr_tensor,
234            torch.sparse_csc: torch.sparse_csc_tensor,
235            torch.sparse_bsr: torch.sparse_bsr_tensor,
236            torch.sparse_bsc: torch.sparse_bsc_tensor,
237        }[layout]
238        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
239        if input_kind == 'list':
240            index_dtypes = [torch.int64]
241        else:
242            index_dtypes = [torch.int32, torch.int64]
243        if dtype.is_floating_point or dtype.is_complex:
244            requires_grad_lst = [False, True]
245        else:
246            requires_grad_lst = [False]
247        for index_dtype in index_dtypes:
248            for expected_device in expected_devices:
249                for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
250                        layout, device=expected_device, dtype=dtype, index_dtype=index_dtype,
251                        # skip zero-sized tensors for list inputs:
252                        enable_zero_sized=input_kind != 'list',
253                        output_tensor=False):
254                    size = kwargs['size']
255                    if shape_and_device_inference and 0 in size:
256                        # skip shape inference for zero-sized tensor
257                        # inputs because (i) the shape determined from
258                        # an empty list is ambiguous, and (ii) the
259                        # size of the plain dimension defined as
260                        # max(plain_indices) is undefined if
261                        # plain_indices has no values
262                        continue
263                    compressed_indices_expect = compressed_indices
264                    plain_indices_expect = plain_indices
265                    values_expect = values
266
267                    if input_kind == 'list':
268                        compressed_indices = compressed_indices.tolist()
269                        plain_indices = plain_indices.tolist()
270                        values = values.tolist()
271
272                    for requires_grad in requires_grad_lst:
273                        if use_factory_function:
274                            if shape_and_device_inference:
275                                sparse = factory_function(
276                                    compressed_indices, plain_indices, values, requires_grad=requires_grad)
277                            else:
278                                sparse = factory_function(
279                                    compressed_indices, plain_indices, values, size,
280                                    dtype=dtype, device=expected_device, requires_grad=requires_grad)
281                        else:
282                            if shape_and_device_inference:
283                                sparse = torch.sparse_compressed_tensor(
284                                    compressed_indices, plain_indices, values,
285                                    layout=layout, requires_grad=requires_grad)
286                            else:
287                                sparse = torch.sparse_compressed_tensor(
288                                    compressed_indices, plain_indices, values, size,
289                                    dtype=dtype, layout=layout, device=expected_device, requires_grad=requires_grad)
290
291                        self.assertEqual(layout, sparse.layout)
292                        self.assertEqual(size, sparse.shape)
293                        self.assertEqual(compressed_indices_expect, compressed_indices_mth(sparse))
294                        self.assertEqual(plain_indices_expect, plain_indices_mth(sparse))
295                        self.assertEqual(values_expect, sparse.values())
296                        self.assertEqual(sparse.device, sparse.values().device)
297                        self.assertEqual(sparse.device, expected_device)
298                        self.assertEqual(sparse.values().requires_grad, requires_grad)
299                        self.assertEqual(sparse.requires_grad, requires_grad)
300                        self.assertFalse(compressed_indices_mth(sparse).requires_grad)
301                        self.assertFalse(plain_indices_mth(sparse).requires_grad)
302
303    @skipMeta
304    @sparse_compressed_nonblock_layouts()
305    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
306    def test_empty(self, layout, device, dtype):
307        ns = [5, 2, 0]
308        batch_shapes = [(), (2,), (2, 3)]
309        compressed_dim = {
310            torch.sparse_csr: -2,
311            torch.sparse_csc: -1,
312        }[layout]
313        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
314        for m, n, b in itertools.product(ns, ns, batch_shapes):
315            shape = (*b, m, n)
316            with torch.sparse.check_sparse_tensor_invariants(enable=False):
317                # torch.empty may return invalid sparse compressed tensors
318                result = torch.empty(shape, dtype=dtype, device=device, layout=layout)
319            self.assertEqual(result.shape, shape)
320            self.assertEqual(result.dtype, dtype)
321            self.assertEqual(result.device, torch.device(device))
322            self.assertEqual(result.layout, layout)
323            self.assertEqual(compressed_indices_mth(result).shape, (*b, shape[compressed_dim] + 1,))
324            self.assertEqual(plain_indices_mth(result).shape, (*b, 0,))
325            self.assertEqual(result.values().shape, (*b, 0,))
326            self.assertEqual(result._nnz(), 0)
327            self.assertEqual(compressed_indices_mth(result).device, torch.device(device))
328            self.assertEqual(plain_indices_mth(result).device, torch.device(device))
329            self.assertEqual(result.values().device, torch.device(device))
330            self.assertEqual(compressed_indices_mth(result).dtype, torch.int64)
331            self.assertEqual(plain_indices_mth(result).dtype, torch.int64)
332            self.assertEqual(result.values().dtype, dtype)
333
334    @skipMeta
335    @sparse_compressed_nonblock_layouts()
336    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
337    def test_empty_errors(self, layout, device, dtype):
338        with self.assertRaisesRegex(RuntimeError,
339                                    "torch.empty: Only batched sparse compressed \\(non-block\\) tensors are supported"
340                                    ", but got size"):
341            torch.empty((5,), dtype=dtype, device=device, layout=layout)
342
343    @skipMeta
344    @all_sparse_compressed_layouts()
345    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
346    def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):
347
348        def get_sparse_compressed_tensor_properties(s):
349            if layout in {torch.sparse_csr, torch.sparse_bsr}:
350                compressed_indices, plain_indices = s.crow_indices(), s.col_indices()
351            else:
352                compressed_indices, plain_indices = s.ccol_indices(), s.row_indices()
353            values = s.values()
354            return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
355                        compressed_indices_shape=compressed_indices.shape,
356                        compressed_indices_dtype=compressed_indices.dtype,
357                        compressed_indices_device=compressed_indices.device,
358                        plain_indices_shape=plain_indices.shape,
359                        plain_indices_dtype=plain_indices.dtype,
360                        plain_indices_device=plain_indices.device,
361                        values_shape=values.shape,
362                        values_dtype=values.dtype,
363                        values_device=values.device)
364
365        for index_dtype in [torch.int32, torch.int64]:
366            for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
367                dense_dim = t.dense_dim()
368                sparse_dim = t.sparse_dim()
369                batch_dim = t.ndim - sparse_dim - dense_dim
370                nnz = t.values().shape[batch_dim]
371                if layout in {torch.sparse_bsr, torch.sparse_bsc}:
372                    blocksize = t.values().shape[batch_dim + 1: batch_dim + 1 + sparse_dim]
373                else:
374                    blocksize = ()
375
376                e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
377                                                                       dtype=dtype, layout=layout, device=device)
378
379                e_prop, t_prop = get_sparse_compressed_tensor_properties(e), get_sparse_compressed_tensor_properties(t)
380                for k, v in e_prop.items():
381                    self.assertEqual(v, t_prop[k], lambda msg: f'{msg} when comparing {k}, expected {t_prop[k]}, got {v}')
382
383    @skipMeta
384    @all_sparse_compressed_layouts()
385    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
386    def test_clone(self, layout, device, dtype):
387        for sparse in self.generate_simple_inputs(
388                layout, device=device, dtype=dtype, index_dtype=torch.int32):
389            cloned_sparse = sparse.clone()
390            self.assertEqual(sparse, cloned_sparse)
391
392    @all_sparse_compressed_layouts()
393    def test_print(self, layout, device):
394        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
395        printed = []
396        for enable_hybrid in [False, True]:
397            # using local patterns for test_print stability
398            patterns = [
399                # 2 x 3 batch of 3 x 2 tensors, trivial blocksize, non-hybrid/hybrid:
400                ([[[[1, 2, 0],
401                    [1, 0, 3]],
402                   [[1, 2, 3],
403                    [1, 0, 0]],
404                   [[1, 0, 0],
405                    [1, 2, 3]]],
406                  [[[0, 2, 0],
407                    [1, 2, 3]],
408                   [[1, 0, 3],
409                    [1, 2, 0]],
410                   [[1, 2, 3],
411                    [0, 2, 0]]]], [(2, 1)], [(), (4,)] if enable_hybrid else [()]),
412                # tensor with non-trivial blocksize, non-hybrid/hybrid:
413                ([[0, 1, 0, 2, 0, 2],
414                  [0, 1, 0, 0, 2, 0],
415                  [3, 3, 3, 0, 0, 0],
416                  [0, 0, 0, 0, 0, 0],
417                  [0, 5, 0, 6, 6, 6],
418                  [5, 0, 5, 6, 6, 6],
419                  [0, 0, 0, 0, 8, 8],
420                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 2)] if enable_hybrid else [()]),
421            ]
422            for index_dtype in [torch.int32, torch.int64]:
423                for dtype in [torch.float32, torch.float64]:
424                    for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
425                            layout, device=device, dtype=dtype, index_dtype=index_dtype, enable_hybrid=enable_hybrid,
426                            enable_non_contiguous_indices=False, enable_non_contiguous_values=False,
427                            enable_zero_sized=False, output_tensor=False, patterns=patterns):
428                        size = tuple(kwargs['size'])
429                        block_ndim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
430                        base_ndim = 2
431                        batch_ndim = compressed_indices.dim() - 1
432                        dense_ndim = values.dim() - batch_ndim - block_ndim - 1
433                        if enable_hybrid and dense_ndim == 0:
434                            # non-hybrid cases are covered by the enable_hybrid==False loop
435                            continue
436                        batchsize = size[:batch_ndim]
437                        basesize = size[batch_ndim:batch_ndim + base_ndim]
438                        densesize = size[batch_ndim + base_ndim:]
439                        assert len(densesize) == dense_ndim
440                        printed.append(f"########## {dtype}/{index_dtype}/size={batchsize}+{basesize}+{densesize} ##########")
441                        x = torch.sparse_compressed_tensor(compressed_indices,
442                                                           plain_indices,
443                                                           values, size, dtype=dtype, layout=layout, device=device)
444                        printed.append("# sparse tensor")
445                        printed.append(str(x))
446                        printed.append(f"# _{compressed_indices_mth.__name__}")
447                        printed.append(str(compressed_indices_mth(x)))
448                        printed.append(f"# _{plain_indices_mth.__name__}")
449                        printed.append(str(plain_indices_mth(x)))
450                        printed.append("# _values")
451                        printed.append(str(x.values()))
452                        printed.append('')
453                    printed.append('')
454        orig_maxDiff = self.maxDiff
455        self.maxDiff = None
456        try:
457            self.assertExpected('\n'.join(printed))
458            self.maxDiff = orig_maxDiff
459        except Exception:
460            self.maxDiff = orig_maxDiff
461            raise
462
463    @skipMeta
464    @all_sparse_compressed_layouts()
465    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
466    def test_copy(self, layout, device, dtype):
467
468        def run_test(shape, blocksize, nnz, index_dtype):
469            a = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
470                                               index_dtype=index_dtype, blocksize=blocksize)
471            b = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
472                                               index_dtype=index_dtype, blocksize=blocksize)
473
474            a.copy_(b)
475
476            self.assertEqual(a, b)
477
478        ns = [(9, 3), (2, 1), (0, 0)]  # (dimension size, the corresponding block size)
479        batch_shapes = [(), (2,), (2, 3)]
480        for ((m, bm), (n, bn), b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
481            blocksize = (bm, bn) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
482            run_test((*b, m, n), blocksize, 0, index_dtype)
483            run_test((*b, m, n), blocksize, m * n, index_dtype)
484
485    @skipMeta
486    @all_sparse_compressed_layouts()
487    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
488    def test_copy_errors(self, layout, device, dtype):
489        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
490        nnz = 6 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 1
491        shape1 = (2 * 6, 3 * 6) if layout in {torch.sparse_bsr, torch.sparse_bsc} else (2, 3)
492        for index_dtype in [torch.int32, torch.int64]:
493            a = self.genSparseCompressedTensor(shape1, 0, dtype=dtype, layout=layout, device=device,
494                                               index_dtype=index_dtype, blocksize=blocksize)
495
496            with self.assertRaisesRegex(RuntimeError,
497                                        "copy of sparse compressed tensors having different layouts is not supported."):
498                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))
499
500            b = self.genSparseCompressedTensor(shape1, nnz, dtype=dtype, layout=layout, device=device,
501                                               index_dtype=index_dtype, blocksize=blocksize)
502            assert a._nnz() != b._nnz(), (a._nnz(), b._nnz())
503            with self.assertRaisesRegex(RuntimeError,
504                                        "only sparse compressed tensors with the same number of specified elements are supported."):
505                a.copy_(b)
506
507            shape2 = tuple(reversed(shape1))
508            c = self.genSparseCompressedTensor(shape2, nnz, dtype=dtype, layout=layout, device=device,
509                                               index_dtype=index_dtype, blocksize=blocksize)
510            with self.assertRaisesRegex(
511                    RuntimeError,
512                    "expected shapes of self and src to match along dimension"):
513                b.copy_(c)
514
515            if blocksize:
516                blocksize1 = tuple(reversed(blocksize))
517                d = self.genSparseCompressedTensor(shape1, nnz, dtype=dtype, layout=layout, device=device,
518                                                   index_dtype=index_dtype, blocksize=blocksize1)
519                with self.assertRaisesRegex(RuntimeError,
520                                            "copy of sparse compressed tensors having different block sizes is not supported"):
521                    b.copy_(d)
522
523    def _smallest_divisor(self, n):
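        # Smallest divisor of n greater than 1 (n itself when n is prime,
        # 1 when n == 1); used in test_consistency below to pick a blocksize
        # that evenly divides a sample's leading dimensions for BSR/BSC.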
524        for i in range(2, int(n ** 0.5) + 1):
525            if n % i == 0:
526                return i
527        return n
528
529    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
530    @all_sparse_compressed_layouts()
531    @ops(_sparse_compressed_ops)
532    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
533    def test_consistency(self, layout, device, dtype, op):
534        """Checks that the op applied to strided and to sparse tensors
535        produces the same results.
536        """
537        if not op.supports_sparse_layout(layout):
538            self.skipTest(f"{op.name} does not support input with {layout} layout")
539
540        # FIXME: remove in followup once integer support is landed for segment_reduce
541        if (layout == torch.sparse_csr and not dtype.is_floating_point
542                and op.name in ('masked.mean', 'masked.amax', 'masked.amin')):
543            self.skipTest(f"{op.name} does not support input with {layout} layout and {dtype} dtype")
544
545        require_mask = isinstance(op, ReductionOpInfo) and 'masked.' in op.name
546
547        samples = []
548        for sample in op.sample_inputs(device, dtype):
549            if sample.input.ndim < 2:
550                continue
551            dense_dim = sample.input.ndim - 2
552            blocksize = (tuple(map(self._smallest_divisor, sample.input.shape[:2]))
553                         if layout in {torch.sparse_bsr, torch.sparse_bsc} else None)
554
555            def _to_sparse(x):
556                if isinstance(x, torch.Tensor):
557                    if blocksize is None:
558                        if x.ndim != sample.input.ndim:
559                            return x
560                    elif x.ndim != sample.input.ndim + 2 or x.shape[-3] % blocksize[0] or x.shape[-2] % blocksize[1]:
561                        return x
562                    return x.clone().to_sparse(layout=layout, blocksize=blocksize, dense_dim=dense_dim)
563                return x
564
565            sparse_sample = sample.transform(_to_sparse)
566            # Some strided samples (with inf, nan elements) appear to share
567            # storage, so we must clone:
568            sample = sample.transform(lambda x: (x.clone() if isinstance(x, torch.Tensor) else x))
569
570            if validate_sample_input_sparse(op, sparse_sample, check_validate=False) is not sparse_sample:
571                # that is, the validation returns the sparse sample
572                # wrapped within an ErrorInput instance
573                continue
574            samples.append((sample, sparse_sample))
575
576        # Fail early to prevent silent success with this test
577        if len(samples) == 0:
578            raise ValueError("Expected at least one tensor with 2 or more dimensions in samples.")
579
580        # Re-define atol and rtol for operations whose result values
581        # are random (and hence, non-comparable) but we still want to
582        # check the shape, dtype, etc. attributes of the results:
583        atol = rtol = None
584        if op.name == 'randn_like':
585            atol = 1e300
586            rtol = 1
587
588        for sample, sparse_sample in samples:
589            expected = op(sample.input, *sample.args, **sample.kwargs)
590            assert torch.is_tensor(expected)
591            output = op(sparse_sample.input, *sparse_sample.args, **sparse_sample.kwargs)
592            assert torch.is_tensor(output)
593            strided_output = output.to_dense()
594            if require_mask and sample.kwargs.get('mask') is not None:
595                output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
596                expected.masked_fill_(~output_mask, 0)
597            self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)
598
599    @skipMeta
600    @all_sparse_compressed_layouts()
601    @all_sparse_compressed_layouts('layout2')
602    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
603    def test_empty_like(self, layout, layout2, device, dtype):
604        for sparse in self.generate_simple_inputs(layout):
605            if layout == layout2:
606                result = torch.empty_like(sparse, layout=layout2)
607                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[result.layout]
608                torch._validate_sparse_compressed_tensor_args(compressed_indices_mth(result),
609                                                              plain_indices_mth(result),
610                                                              result.values(),
611                                                              result.shape,
612                                                              result.layout)
613                self.assertEqual(sparse.shape, result.shape)
614            else:
615                self.assertRaisesRegex(
616                    RuntimeError,
617                    "empty_like with different sparse layout is not supported",
618                    lambda: torch.empty_like(sparse, layout=layout2)
619                )
620
621    @skipMeta
622    @all_sparse_compressed_layouts()
623    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
624    def test_validate(self, layout, device, dtype):
625        def make_zero_batched(t):
626            return torch.empty(*((0,) + t.shape), dtype=t.dtype, device=t.device)
627
628        for index_dtype in [torch.int32, torch.int64]:
629            for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
630                    layout, device=device, dtype=dtype, index_dtype=index_dtype, output_tensor=False):
631                size = kwargs['size']
632                torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout)
633
634                # check empty batch
635                torch._validate_sparse_compressed_tensor_args(
636                    *(make_zero_batched(t) for t in (compressed_indices, plain_indices, values)),
637                    (0,) + size,
638                    layout
639                )
640
641            compressed_indices = torch.tensor([0, 0], dtype=index_dtype)
642            plain_indices = torch.tensor([], dtype=index_dtype)
643            torch._validate_compressed_sparse_indices(layout in {torch.sparse_csr, torch.sparse_bsr},
644                                                      compressed_indices, plain_indices, 1, 1, 0)
645
646    def _generate_invalid_input(self, layout, device):
647        from functools import partial
648
649        def shape(shape, basedim=0):
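            # Remap the two base dimensions starting at `basedim` so that
            # patterns written for CSR carry over to the other layouts:
            # CSC/BSC swap the two base dimensions, and the block layouts
            # also scale them by the (trivial) (1, 1) blocksize.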
650            blocksize = (1, 1)
651            if layout is torch.sparse_csc:
652                shape = shape[:basedim] + (shape[basedim + 1], shape[basedim]) + shape[basedim + 2:]
653            elif layout is torch.sparse_bsc:
654                shape = shape[:basedim] + (shape[basedim + 1] * blocksize[1], shape[basedim] * blocksize[0]) + shape[basedim + 2:]
655            elif layout is torch.sparse_bsr:
656                shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:]
657            return shape
658
659        def values(lst, device=device):
660            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
661                lst = [[[item]] for item in lst]
662            return torch.tensor(lst, device=device)
663
664        tensor = partial(torch.tensor, device=device)
665        values = partial(values, device=device)
666
667        yield ('incontiguous compressed_indices',
668               tensor([0, -1, 2, -1, 4, -1])[::2],
669               tensor([0, 1, 0, 2]),
670               values([1, 2, 3, 4]),
671               shape((2, 3)),
672               'expected compressed_indices to be a contiguous tensor per batch')
673
674        yield ('incontiguous plain_indices',
675               tensor([0, 2, 4]),
676               tensor([0, -1, 1, -1, 0, -1, 2, -1])[::2],
677               values([1, 2, 3, 4]),
678               shape((2, 3)),
679               'expected plain_indices to be a contiguous tensor per batch')
680
681        yield ('0-D compressed_indices',
682               tensor(0),
683               tensor([0, 1, 0, 2]),
684               values([1, 2, 3, 4]),
685               shape((2, 3)),
686               'compressed_indices must have dimensionality >= 1 but got 0')
687
688        yield ('compressed/plain_indices mismatch of dimensionalities',
689               tensor([[0, 2, 4]]),
690               tensor([0, 1, 0, 2]),
691               values([1, 2, 3, 4]),
692               shape((2, 3)),
693               'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively')
694
695        if layout in {torch.sparse_csr, torch.sparse_csc}:
696            yield ('indices and values mismatch of dimensionalities',
697                   tensor([[0, 2, 4]]),
698                   tensor([[0, 1, 0, 2]]),
699                   values([1, 2, 3, 4]),
700                   shape((2, 3)),
701                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1')
702        else:
703            yield ('indices and values mismatch of dimensionalities',
704                   tensor([[0, 2, 4]]),
705                   tensor([[0, 1, 0, 2]]),
706                   values([1, 2, 3, 4]),
707                   shape((2, 3)),
708                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 2\) but got 3')
709
710        yield ('invalid size',
711               tensor([0, 2, 4]),
712               tensor([0, 1, 0, 2]),
713               values([1, 2, 3, 4]),
714               (2,),
715               r'tensor dimensionality must be sum of batch, base, and dense dimensionalities \(=0 \+ 2 \+ 0\) but got 1')
716
717        yield ('invalid batchsize',
718               tensor([[0, 2, 4]]),
719               tensor([[0, 1, 0, 2]]),
720               values([[1, 2, 3, 4]]),
721               shape((2, 2, 3), 1),
722               r'all batch dimensions of compressed_indices \(=\[1\]\), plain_indices \(=\[1\]\), '
723               r'and values \(=\[1\]\) must be equal to tensor batch dimensions \(=\[2\]\)')
724
725        if layout is torch.sparse_bsr:
726            yield ('invalid blocksize',
727                   tensor([0, 2, 4]),
728                   tensor([0, 1, 0, 2]),
729                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
730                   shape((2, 3)),
731                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')
732
733        if layout is torch.sparse_bsc:
734            yield ('invalid blocksize',
735                   tensor([0, 2, 4]),
736                   tensor([0, 1, 0, 2]),
737                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
738                   shape((3, 2)),
739                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')
740
741        yield ('invalid compressed_indices shape',
742               tensor([0, 2, 3, 4]),
743               tensor([0, 1, 0, 2]),
744               values([1, 2, 3, 4]),
745               shape((2, 3)),
746               r'compressed_indices.shape\[-1\] must be equal to the number of compressed_indices_names \+ 1 \(=3\), but got 4')
747
748        yield ('invalid compressed_indices shape',
749               tensor([0, 2, 4]),
750               tensor([0, 1, 0, 1, 2]),
751               values([1, 2, 3, 4]),
752               shape((2, 3)),
753               r'plain_indices.shape\[-1\] must be equal to nnz \(=4\) as defined by values.shape\[0\], but got 5')
754
755        yield ('compressed/plain_indices mismatch of dtype',
756               tensor([0, 2, 4], dtype=torch.int32),
757               tensor([0, 1, 0, 2], dtype=torch.int64),
758               values([1, 2, 3, 4]),
759               shape((2, 3)),
760               r'compressed_indices and plain_indices must have the same dtype, bot got Int and Long, respectively')
761
762        yield ('invalid compressed/plain_indices dtype',
763               tensor([0, 2, 4], dtype=torch.int16),
764               tensor([0, 1, 0, 2], dtype=torch.int16),
765               values([1, 2, 3, 4]),
766               shape((2, 3)),
767               r'compressed_indices and plain_indices dtype must be Int or Long, but got Short')
768
769        # CUDA kernel asserts are not recoverable, so we skip these for now
770        if torch.device(device).type == 'cpu':
771            yield ('invalid compressed_indices[0]',
772                   tensor([1, 2, 4]),
773                   tensor([0, 1, 0, 2]),
774                   values([1, 2, 3, 4]),
775                   shape((2, 3)),
776                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')
777
778            yield ('invalid compressed_indices[0] when nnz == 0',
779                   tensor([1, 0], dtype=torch.int64),
780                   tensor([], dtype=torch.int64),
781                   values([1])[:0],
782                   shape((1, 1)),
783                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')
784
785            yield ('invalid compressed_indices[-1]',
786                   tensor([0, 2, 5]),
787                   tensor([0, 1, 0, 2]),
788                   values([1, 2, 3, 4]),
789                   shape((2, 3)),
790                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')
791
792            yield ('invalid compressed_indices[-1] when nnz == 0',
793                   tensor([0, 1], dtype=torch.int64),
794                   tensor([], dtype=torch.int64),
795                   values([1])[:0],
796                   shape((1, 1)),
797                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')
798
799            yield ('invalid compressed_indices.diff(dim=-1)',
800                   tensor([0, 0, 4]),
801                   tensor([0, 1, 0, 2]),
802                   values([1, 2, 3, 4]),
803                   shape((2, 3)),
804                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')
805
806            yield ('invalid compressed_indices.diff(dim=-1)',
807                   tensor([0, 5, 4]),
808                   tensor([0, 1, 0, 2]),
809                   values([1, 2, 3, 4]),
810                   shape((2, 3)),
811                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')
812
813            yield ('invalid min(plain_indices)',
814                   tensor([0, 2, 4]),
815                   tensor([0, -1, 0, 3]),
816                   values([1, 2, 3, 4]),
817                   shape((2, 3)),
818                   r'`0 <= plain_indices < plain_dim` is not satisfied.')
819
820            yield ('invalid max(plain_indices)',
821                   tensor([0, 2, 4]),
822                   tensor([0, 1, 0, 3]),
823                   values([1, 2, 3, 4]),
824                   shape((2, 3)),
825                   r'`0 <= plain_indices < plain_dim` is not satisfied.')
826
827            yield ('non-coalesced',
828                   tensor([0, 2, 4]),
829                   tensor([1, 0, 0, 2]),
830                   values([1, 2, 3, 4]),
831                   shape((2, 3)),
832                   r'`plain_indices\[..., compressed_indices\[..., i - 1\]:compressed_indices\[..., i\]\] '
833                   'for all i = 1, ..., compressed_dim '
834                   'are sorted and distinct along the last dimension values` is not satisfied.')
835
836        if TEST_CUDA and torch.device(device).type == 'cpu':
837            yield ('indices and values mismatch of device',
838                   torch.tensor([0, 2, 4]),
839                   torch.tensor([0, 1, 0, 1]),
840                   values([1, 2, 3, 4], device='cuda'),
841                   shape((2, 3)),
842                   r'device of compressed_indices \(=cpu\) must match device of values \(=cuda:0\)')
843            yield ('compressed_indices and values mismatch of device',
844                   torch.tensor([0, 2, 4], device='cuda'),
845                   torch.tensor([0, 1, 0, 1]),
846                   values([1, 2, 3, 4]),
847                   shape((2, 3)),
848                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')
849            yield ('compressed/plain_indices mismatch of device',
850                   torch.tensor([0, 2, 4], device='cuda'),
851                   torch.tensor([0, 1, 0, 1]),
852                   values([1, 2, 3, 4], device='cuda'),
853                   shape((2, 3)),
854                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')
855
856        if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2:
857            yield ('indices and values mismatch of device index',
858                   torch.tensor([0, 2, 4], device='cuda:0'),
859                   torch.tensor([0, 1, 0, 1], device='cuda:0'),
860                   values([1, 2, 3, 4], device='cuda:1'),
861                   shape((2, 3)),
862                   r'device of compressed_indices \(=cuda:0\) must match device of values \(=cuda:1\)')
863            yield ('compressed_indices and values mismatch of device index',
864                   torch.tensor([0, 2, 4], device='cuda:0'),
865                   torch.tensor([0, 1, 0, 1], device='cuda:1'),
866                   values([1, 2, 3, 4], device='cuda:0'),
867                   shape((2, 3)),
868                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!')
869
870    @skipMeta
871    @all_sparse_compressed_layouts()
872    @parametrize('target', [subtest('validate_sparse_compressed_tensor_args'),
873                            subtest('sparse_compressed_tensor'),
874                            subtest('sparse_compressed_tensor_no_size')])
875    def test_invalid_input(self, layout, device, target):
876        for label, compressed_indices, plain_indices, values, size, errmsg in self._generate_invalid_input(layout, device):
877            if layout is torch.sparse_bsr:
878                errmsg = errmsg.replace('compressed_indices_name', 'row block').replace('plain_indices_name', 'column block')
879            elif layout is torch.sparse_bsc:
880                errmsg = errmsg.replace('compressed_indices_name', 'column block').replace('plain_indices_name', 'row block')
881            elif layout is torch.sparse_csr:
882                errmsg = errmsg.replace('compressed_indices_name', 'row').replace('plain_indices_name', 'column')
883            elif layout is torch.sparse_csc:
884                errmsg = errmsg.replace('compressed_indices_name', 'column').replace('plain_indices_name', 'row')
885            if layout in {torch.sparse_csr, torch.sparse_bsr}:
886                errmsg = errmsg.replace('compressed_indices', 'crow_indices') \
887                               .replace('plain_indices', 'col_indices') \
888                               .replace('plain_dim', 'ncols') \
889                               .replace('compressed_dim', 'nrows')
890            else:
891                errmsg = errmsg.replace('compressed_indices', 'ccol_indices') \
892                               .replace('plain_indices', 'row_indices') \
893                               .replace('plain_dim', 'nrows') \
894                               .replace('compressed_dim', 'ncols')
895
896            if target == 'sparse_compressed_tensor_no_size' and label in {
897                    'invalid size', 'invalid batchsize', 'invalid compressed_indices shape', 'invalid max(plain_indices)',
898                    'invalid blocksize'}:
899                # Skip these size-dependent inputs since a valid size is estimated from the other arguments
900                continue
901
902            with self.assertRaisesRegex(RuntimeError, errmsg):
903                if target == 'validate_sparse_compressed_tensor_args':
904                    torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout)
905                elif target == 'sparse_compressed_tensor':
906                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout)
907                elif target == 'sparse_compressed_tensor_no_size':
908                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=layout)
909                else:
910                    raise NotImplementedError(target)
911
912    @skipMeta
913    @onlyCPU
914    @largeTensorTest("30GB", "cpu")
915    def test_invalid_input_csr_large(self):
916        rows = 2 ** 31
917        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in row dimension'):
918            torch.sparse_csr_tensor(torch.arange(rows + 1, dtype=torch.int32) // rows,
919                                    torch.tensor([0], dtype=torch.int32),
920                                    torch.tensor([1]), (rows, 1))
921        torch.sparse_csr_tensor(torch.arange(rows + 1, dtype=torch.int64) // rows,
922                                torch.tensor([0], dtype=torch.int64),
923                                torch.tensor([1]), (rows, 1))
924
925        cols = 2 ** 31
926        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in column dimension'):
927            torch.sparse_csr_tensor(torch.arange(2, dtype=torch.int32),
928                                    torch.tensor([0], dtype=torch.int32),
929                                    torch.tensor([1]), (1, cols))
930        torch.sparse_csr_tensor(torch.arange(2, dtype=torch.int64),
931                                torch.tensor([0], dtype=torch.int64),
932                                torch.tensor([1]), (1, cols))
933
934        nnz = 2 ** 31
935        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in nnz'):
936            # nnz cannot be stored in int32 crow_indices
937            # but the `crow_indices[..., -1] == nnz` check happens after the overflow validation.
938            # So we can use `nnz - 1` here to avoid `value cannot be converted to type int32 without overflow`
939            # during construction of crow_indices
940            torch.sparse_csr_tensor(torch.tensor([0, nnz // 2, nnz - 1], dtype=torch.int32),
941                                    torch.arange(nnz // 2, dtype=torch.int32).repeat(2),
942                                    torch.ones(nnz, dtype=torch.int8), (2, nnz // 2))
943        torch.sparse_csr_tensor(torch.tensor([0, nnz // 2, nnz], dtype=torch.int64),
944                                torch.arange(nnz // 2, dtype=torch.int64).repeat(2),
945                                torch.ones(nnz, dtype=torch.int8), (2, nnz // 2))
946
947    @skipMeta
948    @onlyCPU
949    @all_sparse_compressed_layouts()
950    def test_dim(self, layout):
951        for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(layout, output_tensor=False):
952            size = kwargs['size']
953            batch_dim = compressed_indices.dim() - 1
954            sparse_dim = 2
955            block_dim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
956            dense_dim = values.dim() - batch_dim - block_dim - 1
957            sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout)
958            self.assertEqual(sparse.sparse_dim(), sparse_dim)
959            self.assertEqual(sparse.dense_dim(), dense_dim)
960
961
962    @skipMeta
963    @all_sparse_compressed_layouts()
964    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
965    def test_to_dtype(self, layout, device, dtype):
966        # to_dense does not support hybrid inputs
967        for sparse in self.generate_simple_inputs(layout, dtype=dtype, device=device, enable_hybrid=False):
968            for to_dtype in all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16):
969                sparse_to_dtype = sparse.to(to_dtype)
970                dense_to_dtype = sparse.to_dense().to(to_dtype)
971                self.assertEqual(sparse_to_dtype.to_dense(), dense_to_dtype)
972
973    @skipMeta
974    @all_sparse_compressed_layouts()
975    @dtypes(torch.double)
976    def test_pickle(self, layout, dtype, device):
977        import pickle
978
979        for sparse in self.generate_simple_inputs(layout, device=device, dtype=dtype):
980            serialized = pickle.dumps(sparse)
981            sparse_loaded = pickle.loads(serialized)
982
983            self.assertEqual(sparse, sparse_loaded)
984
985    @all_sparse_compressed_layouts()
986    @parametrize("index_dtype", [torch.int32, torch.int64])
987    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
988    def test_select_copy(self, device, dtype, index_dtype, layout):
989
990        def is_view_of(base, other):
991            # a shameless copy of TestViewOps.is_view_of
992            if (
993                not other._is_view() or
994                other is base or
995                other._base is not base or
996                base.device != other.device
997            ):
998                return False
999            if base.device.type in ('cpu', 'cuda'):
1000                if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
1001                    return False
1002            return True
1003
1004        kwargs = dict(device=device, dtype=dtype, index_dtype=index_dtype)
1005        for sparse, dense in zip(self.generate_simple_inputs(layout, **kwargs),
1006                                 self.generate_simple_inputs(torch.strided, **kwargs)):
1007            if layout in {torch.sparse_csr, torch.sparse_bsr}:
1008                n_batchdim = sparse.crow_indices().ndim - 1
1009            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
1010                n_batchdim = sparse.ccol_indices().ndim - 1
1011            else:
1012                assert 0  # unreachable
1013            self.assertEqual(sparse, dense)
1014            for dim in range(sparse.ndim):
1015                if sparse.shape[dim] == 0:
1016                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
1017                        torch.select_copy(sparse, dim, 0)
1018                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
1019                        torch.select_copy(dense, dim, 0)
1020                elif n_batchdim and dim >= n_batchdim and dim < n_batchdim + 2:
1021                    with self.assertRaisesRegex(
1022                            RuntimeError,
1023                            "selecting sparse dimensions is not supported for batched sparse compressed tensors"):
1024                        torch.select_copy(sparse, dim, 0)
1025                else:
1026                    for index in {0, sparse.shape[dim] // 2, sparse.shape[dim] - 1}:
1027                        dense_select = torch.select_copy(dense, dim, index)
1028                        sparse_select = torch.select_copy(sparse, dim, index)
1029                        self.assertEqual(sparse_select, dense_select)
1030                        self.assertFalse(is_view_of(sparse_select.values(), sparse.values()))
1031
1032
1033def _npref_block_addmm_addmv(c, a, b, alpha, beta):
1034    return alpha * (a @ b) + beta * c
1035
1036
1037class TestSparseCSR(TestCase):
1038
1039    def test_csr_stride(self):
1040        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
1041
1042        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have strides"):
1043            a.stride()
1044
1045        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have strides"):
1046            a.stride(-1)
1047
1048    def test_csr_storage(self):
1049        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
1050
1051        with self.assertRaisesRegex(RuntimeError, "Cannot access storage of SparseCsrTensorImpl"):
1052            a.storage()
1053
1054    def test_csr_is_contiguous(self):
1055        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
1056
1057        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have is_contiguous"):
1058            a.is_contiguous()
1059
1060    @onlyCPU
1061    @largeTensorTest("20GB", "cpu")
1062    def test_csr_nnz(self):
1063        # Tests the limits of the number of specified elements in CSR tensors, see gh-102520.
1064        for nnz in [0, 2**31]:
1065            rows, cols = 1, max(nnz, 1)
1066            crow_indices = torch.tensor([0, nnz], dtype=torch.int64)
1067            col_indices = torch.arange(nnz, dtype=torch.int64)
1068            values = torch.ones(nnz, dtype=torch.int8)
1069            a = torch.sparse_csr_tensor(crow_indices, col_indices, values, (rows, cols))
1070            self.assertEqual(a._nnz(), nnz)
1071
1072    def test_csr_double_to_sparse_csr(self):
1073        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
1074        a.to_sparse_csr().to_sparse_csr()
1075
1076    @all_sparse_compressed_layouts()
1077    @parametrize("index_dtype", [torch.int32, torch.int64])
1078    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
1079    def test_select(self, device, dtype, index_dtype, layout):
1080        compressed_indices_mth = {
1081            torch.sparse_csr: torch.Tensor.crow_indices,
1082            torch.sparse_bsr: torch.Tensor.crow_indices,
1083            torch.sparse_csc: torch.Tensor.ccol_indices,
1084            torch.sparse_bsc: torch.Tensor.ccol_indices,
1085        }[layout]
1086
1087        plain_indices_mth = {
1088            torch.sparse_csr: torch.Tensor.col_indices,
1089            torch.sparse_bsr: torch.Tensor.col_indices,
1090            torch.sparse_csc: torch.Tensor.row_indices,
1091            torch.sparse_bsc: torch.Tensor.row_indices,
1092        }[layout]
1093        create_tensor_mth = {
1094            torch.sparse_csr: torch.sparse_csr_tensor,
1095            torch.sparse_bsr: torch.sparse_bsr_tensor,
1096            torch.sparse_csc: torch.sparse_csc_tensor,
1097            torch.sparse_bsc: torch.sparse_bsc_tensor,
1098        }[layout]
1099
1100        shape = (2, 3, 6, 10)
1101        nnz = 6
1102        blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
1103        sparse = self.genSparseCompressedTensor(
1104            shape, nnz, device=device, layout=layout, dtype=dtype, index_dtype=index_dtype, blocksize=blocksize)
1105        comp_indices = compressed_indices_mth(sparse)
1106        plain_indices = plain_indices_mth(sparse)
1107        values = sparse.values()
1108
1109        # select from batch dimensions
1110        sparse_selected12 = sparse.select(1, 2)
1111        expected_sparse_selected12 = create_tensor_mth(comp_indices.select(1, 2).contiguous(),
1112                                                       plain_indices.select(1, 2).contiguous(),
1113                                                       values.select(1, 2).contiguous(),
1114                                                       size=(2, 6, 10),
1115                                                       dtype=dtype,
1116                                                       device=device)
1117        self.assertEqual(expected_sparse_selected12, sparse_selected12)
1118
        # selecting rows/cols is not allowed while batch dims are present, so first index away the batch dims
1120        sparse_non_batched = sparse[0, 0]
1121        # select from sparse dimensions
1122        for select_args in [(0, 0), (1, 1)]:
1123            sparse_selected = sparse_non_batched.select(*select_args)
1124            dense_selected = sparse_non_batched.to_dense().select(*select_args)
1125            self.assertEqual(dense_selected, sparse_selected)
1126
1127        self.assertEqual(sparse[0, 0, 0, 0], sparse.to_dense()[0, 0, 0, 0])
1128        # assigning to sparse through indexing is disabled
1129        with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
1130            sparse[0, 0, 0, 0] = 99.0
1131
1132        # select from sparse dimensions without removing batch dims
1133        msg = "selecting sparse dimensions is not supported for batched sparse compressed tensors."
1134        with self.assertRaisesRegex(RuntimeError, msg):
1135            sparse.select(-2, 0)
1136
1137        with self.assertRaisesRegex(RuntimeError, msg):
1138            sparse.select(-1, 0)
1139
1140    @skipMeta
1141    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
1142    def test_resize(self, device, dtype):
1143
1144        def numel(tensor):
1145            r = 1
1146            for s in tensor.shape:
1147                r *= s
1148            return r
1149
1150        batch_shapes = [(), (2,), (2, 3)]
1151        for index_dtype, b in zip([torch.int32, torch.int64], batch_shapes):
1152            shape = (*b, 2, 3)
1153            nnz = 6
1154            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1155            self.assertEqual(a.numel(), numel(a))
1156
1157            new_shape = (*b, 4, 5)
1158            a.resize_(new_shape)
1159
1160            self.assertEqual(a.shape, new_shape)
1161            # resize to larger shape doesn't add specified elements
1162            self.assertEqual(a._nnz(), nnz)
1163            self.assertEqual(a.numel(), numel(a))
1164
1165            new_shape = (*b, 1, 5)
1166            a.resize_(new_shape)
1167
1168            self.assertEqual(a.shape, new_shape)
1169            # resize to smaller shape trims specified elements
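            # (the 6 specified elements cannot fit into a 1 x 5 matrix, so nnz drops to 5)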
1170            self.assertEqual(a._nnz(), 5)
1171            self.assertEqual(a.numel(), numel(a))
1172
1173            # trim batched dimensions
1174            a.resize_(new_shape[-2], new_shape[-1])
1175            self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
1176            self.assertEqual(a._nnz(), 5)
1177            self.assertEqual(a.numel(), numel(a))
1178
1179    @skipMeta
1180    @dtypes(torch.float, torch.bool)
1181    @all_sparse_compressed_layouts()
1182    def test_resize_as_sparse_compressed(self, device, dtype, layout):
1183
1184        def _check_resize_b_as_a(b, a):
1185            br = b.clone()
1186            br.resize_as_sparse_(a)
1187
1188            # shape is inherited from a
1189            self.assertEqual(a.shape, br.shape)
1190            # other metadata is not affected
1191            self.assertEqual(b.layout, br.layout)
1192            self.assertEqual(b.device, br.device)
1193            self.assertEqual(b.dtype, br.dtype)
1194
1195            def _get_compressed_plain_inds(t):
1196                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[t.layout]
1197                return compressed_indices_mth(t), plain_indices_mth(t)
1198
1199            br_compressed_indices, br_plain_indices = _get_compressed_plain_inds(br)
1200            br_values = br.values()
1201
1202            b_compressed_indices, b_plain_indices = _get_compressed_plain_inds(b)
1203            a_compressed_indices, a_plain_indices = _get_compressed_plain_inds(a)
1204            self.assertEqual(a_plain_indices.shape, br_plain_indices.shape)
1205            self.assertEqual(a_compressed_indices.shape, br_compressed_indices.shape)
1206            # We don't check the content of br_plain_indices and br_compressed_indices
1207            # because it is not well-defined (the content depends on the original
1208            # shape of `b` that `resize_as` ought to discard) nor needed (the
1209            # subsequent operation likely updates the indices and values of `b` anyway).
1210            # the device/dtype of indices should always be unaffected
1211            self.assertEqual(b_plain_indices.dtype, br_plain_indices.dtype)
1212            self.assertEqual(b_plain_indices.device, br_plain_indices.device)
1213            self.assertEqual(b_compressed_indices.dtype, br_compressed_indices.dtype)
1214            self.assertEqual(b_compressed_indices.device, br_compressed_indices.device)
1215            # values are generated empty, shape is updated
1216            self.assertEqual(a.values().shape, br_values.shape)
            # the device/dtype of values should always be unaffected
1218            b_values = b.values()
1219            self.assertEqual(b_values.dtype, br_values.dtype)
1220            self.assertEqual(b_values.device, br_values.device)
1221            # nnz will be picked up from a via new shape of values
1222            self.assertEqual(a._nnz(), br._nnz())
1223
1224            # post resize the invariants of the layout are respected
1225            torch._validate_sparse_compressed_tensor_args(br_compressed_indices, br_plain_indices, br_values, br.shape,
1226                                                          br.layout)
1227
1228        block_sparse = layout in (torch.sparse_bsr, torch.sparse_bsc)
1229        shape = (2, 1, 6, 4)
1230        nnz = 4
1231        blocksize = (2, 1) if block_sparse else ()
1232        for index_dtype in [torch.int32, torch.int64]:
1233            a = self.genSparseCompressedTensor(shape,
1234                                               layout=layout,
1235                                               device=device,
1236                                               index_dtype=index_dtype,
1237                                               dtype=dtype,
1238                                               nnz=nnz,
1239                                               blocksize=blocksize)
1240
1241            # same size, resize should not trigger
1242            b = self.genSparseCompressedTensor(shape,
1243                                               layout=layout,
1244                                               device=device,
1245                                               index_dtype=index_dtype,
1246                                               dtype=dtype,
1247                                               nnz=nnz,
1248                                               blocksize=blocksize)
1249
            # This call will not always trigger a resize: b already has the same size
            # as a, so nothing should happen to b. The invariants checked by the
            # helper should still hold.
1252            _check_resize_b_as_a(b, a)
1253
1254            # same ndim, but bigger, more nnz, different dtype, different blocksize if blocked
1255            b = self.genSparseCompressedTensor(tuple(s * 2 for s in shape),
1256                                               layout=layout,
1257                                               device=device,
1258                                               dtype=torch.chalf,
1259                                               index_dtype=torch.int64 if index_dtype == torch.int32 else torch.int32,
1260                                               nnz=nnz * 2,
1261                                               blocksize=tuple(2 * bi for bi in blocksize))
1262            _check_resize_b_as_a(b, a)
1263
            # different device: only run this check during the CUDA pass, since then we
            # know that we are testing in an environment with multiple devices
1266
1267            # TODO: .cpu() does not seem to work correctly for sparse. Causes a call to `copy_` which
1268            # complains about incompatible nnz between src and self?
1269            if torch.device(device).type == 'cuda' and (layout not in (torch.sparse_bsc, torch.sparse_bsr)):
1270                a_cpu = self.genSparseCompressedTensor(shape,
1271                                                       layout=layout,
1272                                                       device='cpu',
1273                                                       index_dtype=index_dtype,
1274                                                       dtype=dtype,
1275                                                       nnz=nnz,
1276                                                       blocksize=blocksize)
                _check_resize_b_as_a(b, a_cpu)
1278
1279            # error on a strided
1280            a_strided = a.to_dense()
1281            with self.assertRaisesRegex(
1282                    RuntimeError, r'resize_as_sparse_compressed_: src  expected sparse compressed tensor layout'):
1283                b.resize_as_sparse_(a_strided)
1284
1285            # error on b strided
1286            b_strided = b.to_dense()
1287            with self.assertRaisesRegex(
1288                    RuntimeError, r'resize_as_sparse_compressed_: self  expected sparse compressed tensor layout'):
1289                b_strided.resize_as_sparse_(a)
1290
            # error if layouts do not match; transposing induces a layout flip
1292            with self.assertRaisesRegex(RuntimeError,
1293                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
1294                b.transpose(-2, -1).resize_as_sparse_(a)
1295            with self.assertRaisesRegex(RuntimeError,
1296                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
1297                b.resize_as_sparse_(a.transpose(-2, -1))
1298
1299    @skipMeta
1300    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
1301    def test_resize_errors(self, device, dtype):
1302        for index_dtype in [torch.int32, torch.int64]:
1303            shape = (2, 3)
1304            nnz = 6
1305            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1306
1307            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
1308                new_shape = (4,)
1309                a.resize_(new_shape)
1310
1311            # resizing of columns to smaller size is not implemented
1312            with self.assertRaisesRegex(
1313                RuntimeError,
1314                "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported.",
1315            ):
1316                new_shape = (2, 2)
1317                a.resize_(new_shape)
1318
1319    @skipIfTorchDynamo()
1320    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
1321    def test_sparse_csr_from_dense(self, device, dtype):
1322        dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], dtype=dtype, device=device)
1323        sparse = dense.to_sparse_csr()
1324        self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64), sparse.crow_indices())
1325        self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64), sparse.col_indices())
1326        self.assertEqual(torch.tensor([4, 5, 1], dtype=dtype), sparse.values())
1327
1328        dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=dtype, device=device)
1329        sparse = dense.to_sparse_csr()
1330        self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64), sparse.crow_indices())
1331        self.assertEqual(torch.tensor([2, 0], dtype=torch.int64), sparse.col_indices())
1332        self.assertEqual(torch.tensor([1, 1], dtype=dtype), sparse.values())
1333
1334        dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]], dtype=dtype, device=device)
1335        sparse = dense.to_sparse_csr()
1336        self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64), sparse.crow_indices())
1337        self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
1338        self.assertEqual(torch.tensor([2] * 9, dtype=dtype), sparse.values())
1339
1340    def _test_sparse_compressed_to_dense(self, device, dtype, layout):
1341        compressed_format_str = str(layout)[-3:]
1342
1343        def to_compressed(t):
1344            return getattr(t, f"to_sparse_{compressed_format_str}")()
1345
1346        def compressed_constructor(*input, **kwargs):
1347            constructor = getattr(torch, f"sparse_{compressed_format_str}_tensor")
1348            return constructor(*input, **kwargs)
1349
1350        def get_dense_shape(shape, batch_ndim):
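            # For CSC the compressed dimension is the column dimension, so the two
            # sparse dims of the constructed tensor are swapped relative to the CSR
            # case; reverse them here so the dense reference below lines up (the
            # `transpose` helper swaps them back when comparing).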
1351            if layout is torch.sparse_csc:
1352                compressed_dims_slice = slice(batch_ndim + 1, batch_ndim - 1, -1)
1353            else:
1354                compressed_dims_slice = slice(batch_ndim, batch_ndim + 2)
1355            return shape[:batch_ndim] + shape[compressed_dims_slice] + shape[batch_ndim + 2:]
1356
1357        def transpose(t, batch_ndim):
1358            if layout is torch.sparse_csc:
1359                return t.transpose(batch_ndim, batch_ndim + 1)
1360            return t
1361
1362        mn = [5, 2, 0]
1363        for (m, n) in itertools.product(mn, mn):
1364            size = (m, n)
1365            dense = make_tensor(size, dtype=dtype, device=device)
1366            sparse = to_compressed(dense)
1367            self.assertEqual(sparse.to_dense(), dense)
1368
1369        batch_shape = (2, 3)
1370        compressed_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
1371        plain_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
1372        values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
1373        sparse = compressed_constructor(compressed_indices, plain_indices, values, dtype=dtype, device=device)
1374        dense_shape = get_dense_shape(sparse.shape, len(batch_shape))
1375        dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(dense_shape)
1376        self.assertEqual(sparse.to_dense(), transpose(dense, len(batch_shape)))
1377
1378    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
1379    def test_sparse_csr_to_dense(self, device, dtype):
1380        self._test_sparse_compressed_to_dense(device, dtype, torch.sparse_csr)
1381
1382    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
1383    def test_sparse_csc_to_dense(self, device, dtype):
1384        self._test_sparse_compressed_to_dense(device, dtype, torch.sparse_csc)
1385
1386    @skipMeta
1387    @skipCPUIfNoMklSparse
1388    @coalescedonoff
1389    @dtypes(torch.double)
1390    def test_coo_to_csr_convert(self, device, dtype, coalesced):
1391        with self.assertRaisesRegex(RuntimeError, "Input is supposed to be a vector"):
1392            torch._convert_indices_from_coo_to_csr(
1393                torch.randint(100, (5, 5), device=device),
1394                size=100)
1395
1396        size = (5, 5)
1397        sparse_dim = 2
1398        nnz = 10
1399        sparse_coo, _, _ = self.genSparseTensor(size, sparse_dim, nnz, coalesced, device, dtype)
1400        sparse_csr = sparse_coo.to_sparse_csr()
1401
1402        self.assertTrue(sparse_csr.is_sparse_csr)
1403        self.assertEqual(sparse_csr.to_dense(), sparse_coo.to_dense())
1404
1405        vec = torch.randn((5, 1), dtype=dtype, device=device)
1406        coo_product = sparse_coo.matmul(vec)
1407        csr_product = sparse_csr.matmul(vec)
1408
1409        self.assertEqual(coo_product, csr_product)
1410
1411        vec = torch.randn((100, 1), dtype=dtype, device=device)
1412        index = torch.tensor([
1413            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
1414            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
1415        ], dtype=torch.int32)
1416        values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
1417        coo = torch.sparse_coo_tensor(index, values, torch.Size([100, 100]), dtype=dtype, device=device)
1418        csr = coo.to_sparse_csr()
1419
1420        self.assertEqual(coo.matmul(vec), csr.matmul(vec))
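        # to_sparse_csr() orders the entries by row (and by column within each row);
        # the expected col_indices and values below are therefore the COO entries
        # above reordered row by row.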
1421
1422        col_indices = torch.tensor([
1423            31, 92, 65, 50, 34, 62, 22, 56, 74, 89
1424        ], dtype=torch.int64, device=device)
1425        self.assertEqual(csr.col_indices(), col_indices)
1426
1427        values = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device)
1428        self.assertEqual(csr.values(), values)
1429
1430    @parametrize("blocksize", [2, 4])
1431    @dtypes((torch.double, torch.int32), (torch.double, torch.int64))
1432    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1433    @skipMeta
1434    def test_csr_to_block_csr(self, device, dtypes, blocksize):
1435        for shape in [(24, 24), (12, 24)]:
1436            dtype, index_dtype = dtypes
1437            m, k = shape
1438            nnz = random.randint(0, m * k)
1439            t = self.genSparseCSRTensor((m * blocksize, k * blocksize), nnz, dtype=dtype,
1440                                        device=device, index_dtype=index_dtype)
1441            st = sp.csr_matrix((t.values().cpu(), t.col_indices().cpu(), t.crow_indices().cpu()), shape=tuple(t.size()))
1442            block_t = t.to_sparse_bsr((blocksize, blocksize))
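            # BSR stores one dense (blocksize x blocksize) block per specified block,
            # so values() is 3D: (number of blocks, blocksize, blocksize).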
1443            self.assertEqual(block_t.values().dim(), 3)
1444            self.assertTrue(block_t.layout == torch.sparse_bsr)
1445            block_st = st.tobsr(blocksize=(blocksize, blocksize))
1446            block_st.sort_indices()
1447            self.assertEqual(block_t.values().cpu(), block_st.data)
1448            self.assertEqual(block_t.col_indices().cpu(), torch.tensor(block_st.indices).to(index_dtype))
1449            self.assertEqual(block_t.crow_indices().cpu(), torch.tensor(block_st.indptr).to(index_dtype))
1450
1451    @dtypes(torch.double)
1452    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1453    def test_csr_to_block_csr_errors(self, device, dtype):
1454        for index_dtype in [torch.int32, torch.int64]:
1455            nnz = 15
1456            t = self.genSparseCSRTensor((16, 16), nnz, dtype=dtype,
1457                                        device=device, index_dtype=index_dtype)
1458
1459            with self.assertRaisesRegex(RuntimeError,
1460                                        r"tensor sparse size \(.*,.*\) must be divisible by given blocksize \(.*,.*\)"):
1461                block_t = t.to_sparse_bsr((5, 5))
1462
1463    # TODO: Support auto generation of device check for sparse tensors
1464    # See: https://github.com/pytorch/pytorch/issues/59058
1465    @onlyCUDA
1466    @dtypes(torch.double)
1467    def test_matmul_device_mismatch(self, device, dtype):
1468        cpu = torch.rand((10, 10))
1469        cuda = cpu.cuda()
1470        for s, m1, m2 in itertools.product((cpu, cuda), repeat=3):
1471            csr = m1.to_sparse()
1472            if s.device == csr.device == m2.device:
1473                torch.addmm(s, csr, m2)
1474            else:
1475                with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
1476                    torch.addmm(s, csr, m2)
1477
1478    @skipCPUIfNoMklSparse
1479    @skipCUDAIfNoSparseGeneric
1480    @dtypes(*floating_and_complex_types())
1481    @dtypesIfCUDA(*floating_and_complex_types_and(
1482                  *[torch.half] if SM53OrLater else [],
1483                  *[torch.bfloat16] if SM80OrLater else []))
1484    def test_csr_matvec(self, device, dtype):
1485
1486        if TEST_WITH_ROCM and (dtype == torch.half or dtype == torch.bfloat16):
1487            self.skipTest("ROCm doesn't work with half dtypes correctly.")
1488
1489        side = 100
1490        for index_dtype in [torch.int32, torch.int64]:
1491            csr = self.genSparseCSRTensor((side, side), 1000, device=device, dtype=dtype, index_dtype=index_dtype)
1492            vec = torch.randn(side, dtype=dtype, device=device)
1493
1494            res = csr.matmul(vec)
1495            expected = csr.to_dense().matmul(vec)
1496
1497            self.assertEqual(res, expected)
1498
1499            bad_vec = torch.randn(side + 10, dtype=dtype, device=device)
1500            err_msg = "size mismatch, got"
1501            with self.assertRaisesRegex(RuntimeError, err_msg):
1502                csr.matmul(bad_vec)
1503
1504    @onlyCUDA
    # Note: the test passes on CUDA when ROCm is not available:
1506    @skipCUDAIfRocmVersionLessThan((5, 2))
1507    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
1508    def test_baddbmm(self, device, dtype):
1509
1510        # TODO: disable the invariant checks within torch.baddbmm that
1511        # constructs unconventional csr tensors leading to
1512        # RuntimeError: tensor dimensionality must be sum of batch,
1513        #     base, and dense dimensionalities (=0 + 2 + 0) but got 3
1514        # when invariant checking is enabled. When done, undecorate run_test.
1515        @torch.sparse.check_sparse_tensor_invariants(enable=False)
1516        def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
1517            alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
1518            beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
1519            b = b.mH if (op_b and a.shape == b.shape) else b
1520
1521            actual = torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta)
1522
1523            out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
1524            torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta, out=out)
1525
1526            expected = [torch.addmm(c[i], a, b[i], alpha=alpha, beta=beta) for i in range(c.shape[0])]
1527            expected = torch.stack(expected, 0)
1528
1529            self.assertEqual(actual, out)
1530            self.assertEqual(actual, expected)
1531
1532        for index_dtype in [torch.int32, torch.int64]:
1533            for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]):
1534                nnz = random.randint(0, m * k)
1535                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1536
1537                # a_batched is a regular CSR tensor but with a batch dimension in the shape
1538                a_batched = torch.sparse_csr_tensor(
1539                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)
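                # a_batched reuses the indices/values of the 2D tensor `a` for every
                # batch element, so baddbmm should behave as if `a` were broadcast
                # across the batch (cf. the per-slice addmm reference in run_test).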
1540
1541                b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
1542                c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
1543                for op_b, op_out in itertools.product([True, False], repeat=2):
1544                    run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device)
1545
1546    @onlyCUDA
1547    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
1548    @skipCUDAIfNoSparseGeneric
1549    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
1550    def test_bmm(self, device, dtype):
1551        def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None):
1552            b = b.mH if (op_b and a.shape == b.shape) else b
1553
1554            actual = torch.bmm(a_batched, b)
1555
1556            out = torch.empty_like(actual.mH if op_out and a.shape == b.shape else actual)
1557            torch.bmm(a_batched, b, out=out)
1558
1559            expected = [torch.mm(a, b[i]) for i in range(b.shape[0])]
1560            expected = torch.stack(expected, 0)
1561
1562            self.assertEqual(actual, out)
1563            self.assertEqual(actual, expected)
1564
1565        for index_dtype in [torch.int32, torch.int64]:
1566            for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]):
1567                nnz = random.randint(0, m * k)
1568                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1569
1570                # a_batched is a regular CSR tensor but with a batch
1571                # dimension in the shape. It is unorthodox in PyTorch
1572                # to represent a batch sparse tensor in this way,
1573                # hence checking the tensor invariants is locally
1574                # turned off.
1575                a_batched = torch.sparse_csr_tensor(
1576                    a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False)
1577
1578                b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
1579                for op_b, op_out in itertools.product([True, False], repeat=2):
1580                    run_test(a, a_batched, b, op_b, op_out, dtype=dtype, device=device)
1581
1582    def run_test_block_addmm_addmv(self,
1583                                   addmv_addmm,
1584                                   c,
1585                                   a,
1586                                   b,
1587                                   op_b=False,
1588                                   op_out=False,
1589                                   *,
1590                                   dtype=None,
1591                                   device=None,
1592                                   ref=_npref_block_addmm_addmv):
1593        alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
1594        beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
1595        b = b.mH if (op_b and a.shape == b.shape) else b
1596
1597        actual = addmv_addmm(c, a, b, alpha=alpha, beta=beta)
1598
1599        out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
1600        addmv_addmm(c, a, b, alpha=alpha, beta=beta, out=out)
1601        expected = ref(c, a, b, alpha, beta)
1602
1603        self.assertEqual(actual, out)
1604        self.assertEqual(actual, expected, lambda msg: f"{msg}\na={a}\nc={c}\nb={b}\nalpha={alpha} beta={beta}")
1605
1606    # TODO: block_size 1 is broken
1607    @parametrize("block_size", [2, 3])
1608    @parametrize("index_dtype", [torch.int32, torch.int64])
1609    @parametrize("noncontiguous", [True, False])
1610    @skipCPUIfNoMklSparse
1611    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1612    @skipIfTorchDynamo("raises 'sparse matrix length is ambiguous; use getnnz()'")
1613    @dtypes(*floating_and_complex_types())
1614    @dtypesIfCUDA(*floating_and_complex_types_and(
1615                  *[torch.half] if SM53OrLater else [],
1616                  *[torch.bfloat16] if SM80OrLater else []))
1617    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
1618                        torch.float64: 1e-5, torch.complex128: 1e-5,
1619                        torch.float16: 1e-3, torch.bfloat16: 1e-3})
1620    def test_block_addmm(self, device, dtype, index_dtype, block_size, noncontiguous):
1621
1622        def make_transposed_addmm_op(f):
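            # Wraps an addmm-like op so that it is exercised on transposed operands,
            # using the identity (beta * c + alpha * a @ b)^T = beta * c^T + alpha * b^T @ a^T
            # (plain transpose, not conjugate transpose).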
1623
1624            def tt(t):
1625                if isinstance(t, torch.Tensor):
1626                    return t.transpose(-2, -1)
1627                else:
1628                    # assume numpy/scipy spmatrix
1629                    return t.transpose()
1630
1631            @functools.wraps(f)
1632            def wrapper(c, a, b, alpha=None, beta=None, out=None):
1633                if out is not None:
1634                    # the ref takes no out kwarg
1635                    assert isinstance(out, torch.Tensor)
1636                    # transpose inplace to propagate out to checking context
1637                    out.transpose_(-2, -1)
1638                    return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta, out=out)
1639                else:
1640                    return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta)
1641
1642            return wrapper
1643
1644        def ref_sp_numpy(c, a, b, alpha=None, beta=None, out=None):
1645
1646            def prep_input(t):
1647
1648                def to_sp_block_compressed(t):
1649
1650                    if t.layout is torch.sparse_bsc:
1651                        tt = t.transpose(-1, -2)
1652                    else:
1653                        tt = t
1654
1655                    t_sp_bsr = sp.bsr_matrix(
1656                        (
1657                            tt.values().cpu().numpy(),
1658                            tt.col_indices().cpu().numpy(),
1659                            tt.crow_indices().cpu().numpy(),
1660                        ),
1661                        shape=tt.shape,
1662                    )
1663
1664                    if t.layout is torch.sparse_bsc:
1665                        return t_sp_bsr.transpose()
1666                    else:
1667                        return t_sp_bsr
1668
1669                if t.layout is not torch.strided:
1670                    return to_sp_block_compressed(t)
1671                else:
1672                    return t.cpu().resolve_conj().numpy()
1673
1674            res = _npref_block_addmm_addmv(
1675                *(prep_input(t) for t in (c, a, b)),
1676                alpha,
1677                beta
1678            )
1679
1680            if out is not None:
1681                out.copy_(res)
1682                return out
1683            else:
1684                return res
1685
1686        def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
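            # Reference for half/bfloat16: perform the matmul in float32 and cast
            # back, which keeps the reference numerically comparable to the
            # low-precision kernel output.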
1687            res = alpha * (a.to_dense().to(torch.float32) @ b.to_dense().to(torch.float32)).to(a.dtype) + beta * c
1688            if out is not None:
1689                out.copy_(res)
1690                return out
1691            else:
1692                return res
1693
1694        if dtype in (torch.half, torch.bfloat16):
1695            ref = ref_half_bfloat16
1696        else:
1697            ref = ref_sp_numpy
1698
1699        for (m, n, k) in itertools.product([2, 5], repeat=3):
1700            nnz = random.randint(0, m * k)
1701            a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1702            a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
1703            a_data = a_data.mT if noncontiguous else a_data
1704            a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
1705                                        a_data, (m * block_size, k * block_size), check_invariants=False)
1706            b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
1707            c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
1708            for op_b, op_out in itertools.product([True, False], repeat=2):
1709                self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device, ref=ref)
1710                self.run_test_block_addmm_addmv(make_transposed_addmm_op(torch.addmm),
1711                                                c,
1712                                                a,
1713                                                b,
1714                                                op_b,
1715                                                op_out,
1716                                                dtype=dtype,
1717                                                device=device,
1718                                                ref=make_transposed_addmm_op(ref))
1719
1720    @parametrize("block_size", [2, 3])
1721    @parametrize("index_dtype", [torch.int32, torch.int64])
1722    @parametrize("noncontiguous", [True, False])
1723    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1724    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
1725    def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous):
1726        # TODO: Explicitly disable block size 1 support
1727        # if (TEST_WITH_ROCM or not TEST_CUSPARSE_GENERIC) and block_size == 1:
1728        #     return
1729        def ref_block_addmv(c, a, b, alpha, beta):
1730            return _npref_block_addmm_addmv(c, a.to_dense(), b, alpha, beta)
1731
1732        for (m, k) in itertools.product([2, 5], repeat=2):
1733            nnz = random.randint(0, m * k)
1734            if not noncontiguous:
1735                a = self.genSparseCSRTensor((m * block_size, k * block_size), nnz,
1736                                            dtype=dtype, device=device, index_dtype=index_dtype)
1737                a = a.to_sparse_bsr((block_size, block_size))
1738            else:
1739                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1740                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
1741                a_data = a_data.mT if noncontiguous else a_data   # Test column-major blocks
1742                a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
1743                                            a_data, (m * block_size, k * block_size), check_invariants=False)
1744            b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
1745            c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
1746            self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device, ref=ref_block_addmv)
1747
1748    @parametrize("matrix_shape", [(3, 3), (5, 7), (11, 9)], name_fn=lambda x: "shape_{}x{}".format(*x))
1749    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
1750    @onlyCPU
1751    def test_addmv(self, device, dtype, matrix_shape):
1752        mat = torch.randn(matrix_shape, dtype=dtype, device=device)
1753        mat[mat.real < 0] = 0
1754        sparse_mat = mat.to_sparse_csr()
1755        mvec = torch.randn((mat.size(1),), dtype=dtype, device=device)
1756        avec = torch.randn((mat.size(0),), dtype=torch.float64, device=device)
1757        ref_output = torch.addmv(avec, mat, mvec)
1758        output = torch.addmv(avec, sparse_mat, mvec)
1759        self.assertEqual(ref_output, output)
1760
1761    @parametrize("block_size", [2, 3])
1762    @parametrize("index_dtype", [torch.int32, torch.int64])
1763    @parametrize("noncontiguous", [True, False])
1764    @skipCPUIfNoMklSparse
1765    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
1766    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
1767    def test_block_triangular_solve(self, device, dtype, index_dtype, block_size, noncontiguous):
1768        def run_test(a, b, upper, transpose, unitriangular, op_out):
1769            if unitriangular and self.device_type == 'cpu':
1770                # TODO: When unitriangular=True results are not correct on CPU
1771                return
1772
1773            if not upper and self.device_type == 'cpu':
1774                # TODO: When upper=False some generated inputs might crash on CPU
1775                return
1776
1777            actual = torch.triangular_solve(b, a, upper=upper, unitriangular=unitriangular, transpose=transpose)
1778            actual_X = actual.solution
1779            actual_A_clone = actual.cloned_coefficient
1780            self.assertTrue(actual_A_clone.numel() == 0)
1781            if a._nnz() == 0:
1782                self.assertTrue(actual_X.isnan().all())
1783                return
1784
            # TODO: replace with a torch method once to_dense() is implemented for block sparse tensors
1786            a_bsr = sp.bsr_matrix(
1787                (
1788                    a.values().cpu().numpy(),
1789                    a.col_indices().cpu().numpy(),
1790                    a.crow_indices().cpu().numpy(),
1791                ),
1792                shape=a.shape,
1793            )
1794            expected_X, _ = torch.triangular_solve(
1795                b,
1796                torch.tensor(a_bsr.todense(), device=device),
1797                transpose=transpose,
1798                upper=upper,
1799                unitriangular=unitriangular)
1800
1801            if expected_X.isnan().any():
1802                # TODO: zeros on the diagonal are not handled for CPU path
1803                # there's no way to query this info from MKL
1804                if self.device_type == 'cuda' and not TEST_WITH_ROCM:
1805                    self.assertTrue(actual_X.isnan().any() or actual_X.isinf().any())
1806                return
1807
1808            self.assertEqual(actual_X, expected_X)
1809
1810            out = torch.empty_like(b.mH if op_out and a.shape == b.shape else b)
1811            torch.triangular_solve(
1812                b, a,
1813                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
1814            )
1815            self.assertEqual(out, actual_X)
1816            self.assertEqual(out, expected_X)
1817
1818        for (m, k) in itertools.product([2, 3], [1, 3]):
1819            nnz = random.randint(0, m * m)
1820            if not noncontiguous:
1821                a = self.genSparseCSRTensor((m * block_size, m * block_size), nnz,
1822                                            dtype=dtype, device=device, index_dtype=index_dtype)
1823                a = a.to_sparse_bsr((block_size, block_size))
1824            else:
1825                a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1826                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
1827                a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
1828                a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(),
1829                                            a_data, (m * block_size, m * block_size), check_invariants=False)
1830            b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous)
1831
1832            for (upper, unitriangular, transpose, op_out) in itertools.product([True, False], repeat=4):
                run_test(a, b, upper=upper, transpose=transpose, unitriangular=unitriangular, op_out=op_out)
1834
1835    @skipCPUIfNoMklSparse
1836    @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported")
1837    @dtypes(torch.double)
1838    def test_mm(self, device, dtype):
1839        def test_shape(di, dj, dk, nnz0=None, nnz1=None):
1840            for index_dtype in [torch.int32, torch.int64]:
1841                alpha = random.random()
1842                beta = random.random()
1843
1844                def _test_addmm(t, x, y):
1845                    # TODO: addmm doesn't support strided result for sparse inputs.
1846                    # res = beta * t  + alpha * (x @ y)
1847                    res = torch.addmm(t, x, y, beta=beta, alpha=alpha)
1848                    expected = torch.addmm(t, x.to_dense(), y.to_dense(), beta=beta, alpha=alpha)
1849                    self.assertEqual(res, expected)
1850
1851                    res = torch.addmm(t, x, y)
1852                    expected = torch.addmm(t, x.to_dense(), y.to_dense())
1853                    self.assertEqual(res, expected)
1854
1855                def _test_mm(x, y):
1856                    res = torch.mm(x, y)
1857                    expected = torch.mm(x.to_dense(), y.to_dense())
1858                    if x.layout is torch.strided or y.layout is torch.strided:
1859                        self.assertEqual(res.layout, torch.strided)
1860                    else:
1861                        self.assertEqual(res.layout, torch.sparse_csr)
1862                    self.assertEqual(res.to_dense(), expected)
1863
1864                def _test(t, x, y):
1865                    _test_addmm(t, x, y)
1866                    _test_mm(x, y)
1867
1868                if nnz0 is None:
1869                    nnz0 = random.randint(di * dk // 2, di * dk)
1870                t = torch.randn(di, dj, dtype=dtype, device=device)
1871                x = self.genSparseCSRTensor((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
1872                y = torch.randn(dk, dj, dtype=dtype, device=device)
1873                _test(t, x, y)
1874
1875                t = torch.randn(di, dj, dtype=dtype, device=device)
1876                x = self.genSparseCSCTensor((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
1877                y = torch.randn(dk, dj, dtype=dtype, device=device)
1878                _test(t, x, y)
1879
1880                if nnz1 is None:
1881                    nnz1 = random.randint(dk * dj // 2, dk * dj)
1882                t = torch.randn(di, dj, dtype=dtype, device=device)
1883                x = torch.randn(di, dk, dtype=dtype, device=device)
1884                y = self.genSparseCSRTensor((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
1885                _test(t, x, y)
1886
1887                t = torch.randn(di, dj, dtype=dtype, device=device)
1888                x = torch.randn(di, dk, dtype=dtype, device=device)
1889                y = self.genSparseCSCTensor((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
1890                _test(t, x, y)
1891
1892                x_shape, y_shape = x.shape, y.shape
1893
1894                gen_csr_csc = [self.genSparseCSRTensor, self.genSparseCSCTensor]
1895
1896                # Test mm({CSR, CSC}, {CSR, CSC})
1897                for gen_x, gen_y in itertools.product(gen_csr_csc, gen_csr_csc):
1898                    x = gen_x(x_shape, nnz0, device=device, dtype=dtype, index_dtype=index_dtype)
1899                    y = gen_y(y_shape, nnz1, device=device, dtype=dtype, index_dtype=index_dtype)
1900                    _test_mm(x, y)
1901
1902        def test_empty_inputs(lhs_layout, rhs_layout):
1903            xd = torch.rand(10, 0, device=device, dtype=dtype)
1904            yd = xd.transpose(-2, -1)
1905            zd = torch.rand(0, 0, device=device, dtype=dtype)
1906
1907            xls, yls, zls = (t.to_sparse(layout=lhs_layout) for t in (xd, yd, zd))
1908            xrs, yrs, zrs = (t.to_sparse(layout=rhs_layout) for t in (xd, yd, zd))
1909
1910            for ls, rs, ld, rd in [(xls, yrs, xd, yd), (xls, zrs, xd, zd), (zls, yrs, zd, yd), (zls, zrs, zd, zd)]:
1911                res_sparse = ls @ rs
1912                res_dense = ld @ rd
1913                self.assertEqual(res_sparse.to_dense(), res_dense)
1914
1915        def test_orthogonal_inputs(lhs_layout, rhs_layout):
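            # x is nonzero only in its left half (columns 0-1) and y only in its
            # bottom half (rows 2-3), so every term of x @ y is zero.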
1916            ones = torch.ones(2, 2, device=device, dtype=dtype)
1917            zeros = torch.zeros(2, 2, device=device, dtype=dtype)
1918            x = torch.cat((ones, zeros), -1).to_sparse(layout=lhs_layout)
1919            y = torch.cat((zeros, ones), -2).to_sparse(layout=rhs_layout)
1920            res = x @ y
1921            res_expected = torch.zeros(*res.shape, device=device, dtype=dtype, layout=res.layout)
1922            self.assertEqual(res, res_expected)
1923
1924        for lhs_layout, rhs_layout in itertools.product([torch.sparse_csr, torch.sparse_csc], repeat=2):
1925            test_empty_inputs(lhs_layout, rhs_layout)
1926            test_orthogonal_inputs(lhs_layout, rhs_layout)
1927
1928        for i in [2, 4]:
1929            for j in [2, 4, 7]:
1930                for k in [2, 3, 7]:
1931                    test_shape(i, j, k)
1932        test_shape(4, 4, 4, 0, 0)
1933
1934    @skipCPUIfNoMklSparse
1935    @dtypes(*floating_and_complex_types())
1936    @dtypesIfCUDA(*floating_and_complex_types_and(
1937                  *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
1938                  *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
1939    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
1940    def test_sparse_mm(self, device, dtype):
1941        def test_shape(d1, d2, d3, nnz, transposed, index_dtype):
1942            if transposed:
1943                D = torch.randn(d3, d2, dtype=dtype, device=device).t_()
1944            else:
1945                D = torch.randn(d2, d3, dtype=dtype, device=device)
1946            S = self.genSparseCSRTensor((d1, d2), nnz, device=device, dtype=dtype, index_dtype=index_dtype)
1947            S_dense = S.to_dense()
1948            self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D))
1949
1950        for index_dtype in [torch.int32, torch.int64]:
1951            test_shape(7, 8, 9, 20, False, index_dtype)
1952            test_shape(7, 8, 9, 20, True, index_dtype)
1953
1954    @dtypes(*floating_and_complex_types())
1955    @dtypesIfCUDA(*floating_and_complex_types_and(
1956                  *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [],
1957                  *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else []))
1958    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
1959    def test_sparse_addmm(self, device, dtype):
1960        def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None):
1961            if alpha_beta is None:
1962                alpha = random.random()
1963                beta = random.random()
1964            else:
1965                alpha, beta = alpha_beta
1966            if broadcast:
1967                D1 = make_tensor((), dtype=dtype, device=device)
1968            else:
1969                D1 = make_tensor([n, p], dtype=dtype, device=device)
1970            D2 = make_tensor([m, p], dtype=dtype, device=device)
1971            S = self.genSparseCSRTensor([n, m], nnz, dtype=dtype, device=device, index_dtype=index_dtype)
1972            S_dense = S.to_dense()
1973            Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha)
1974            Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha)
1975            self.assertEqual(Y, Y_dense)
1976
1977        for index_dtype in [torch.int32, torch.int64]:
1978            test_shape(7, 8, 9, 20, False, index_dtype, None)
1979            test_shape(7, 8, 9, 20, True, index_dtype, None)
1980            test_shape(7, 8, 9, 20, False, index_dtype, (1, 0))
1981            test_shape(7, 8, 9, 20, True, index_dtype, (1, 0))
1982            test_shape(7, 8, 9, 20, False, index_dtype, (1, 1))
1983            test_shape(7, 8, 9, 20, True, index_dtype, (1, 1))
1984
1985    @skipCPUIfNoMklSparse
1986    @dtypes(*floating_and_complex_types())
1987    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
1988                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
1989    @dtypesIfCUDA(*floating_types_and(torch.complex64,
1990                                      *[torch.bfloat16] if SM80OrLater else [],
1991                                      *[torch.half] if SM53OrLater else [],
1992                                      *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
1993    @sparse_compressed_nonblock_layouts()
1994    @skipCUDAIf(
1995        not _check_cusparse_spgemm_available(),
1996        "cuSparse Generic API SpGEMM is not available"
1997    )
1998    def test_addmm_all_sparse_csr(self, device, dtype, layout):
1999        M = torch.randn(10, 25, device=device).to(dtype)
2000        m1 = torch.randn(10, 50, device=device).to(dtype)
2001        m2 = torch.randn(50, 25, device=device).to(dtype)
2002        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="all_sparse")
2003
2004        # Test 0-strided
2005        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
2006        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
2007        m2 = torch.randn(50, 25, device=device).to(dtype)
2008        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="all_sparse")
2009
2010        # Test beta=0, M=nan
2011        M = torch.full((10, 25), float('nan'), device=device).to(dtype)
2012        m1 = torch.randn(10, 50, device=device).to(dtype)
2013        m2 = torch.randn(50, 25, device=device).to(dtype)
2014        _test_addmm_addmv(self, torch.addmm, M, m1, m2, beta=0, layout=layout, mode="all_sparse")
2015
2016        # Test transpose
2017        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
2018            def maybe_transpose(cond, m):
2019                if not cond:
2020                    return m
2021                return m.t().clone(memory_format=torch.contiguous_format).t()
2022
2023            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
2024            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
2025            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
2026            _test_addmm_addmv(self, torch.addmm, M, m1, m2, transpose_out=t4, layout=layout, mode="all_sparse")
2027
2028    @onlyCPU
2029    @skipCPUIfNoMklSparse
2030    @dtypes(*floating_and_complex_types())
2031    @sparse_compressed_nonblock_layouts()
2032    def test_addmm_dense_result(self, device, dtype, layout):
2033        M = torch.randn(10, 25, device=device).to(dtype)
2034        m1 = torch.randn(10, 50, device=device).to(dtype)
2035        m2 = torch.randn(50, 25, device=device).to(dtype)
2036        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="dense_result")
2037
2038        # Test 0-strided
2039        M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
2040        m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
2041        m2 = torch.randn(50, 25, device=device).to(dtype)
2042        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="dense_result")
2043
2044        # Test beta=0, M=nan
2045        M = torch.full((10, 25), float('nan'), device=device).to(dtype)
2046        m1 = torch.randn(10, 50, device=device).to(dtype)
2047        m2 = torch.randn(50, 25, device=device).to(dtype)
2048        _test_addmm_addmv(self, torch.addmm, M, m1, m2, beta=0, layout=layout, mode="dense_result")
2049
2050        # Test transpose
2051        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
2052            def maybe_transpose(cond, m):
2053                if not cond:
2054                    return m
2055                return m.t().clone(memory_format=torch.contiguous_format).t()
2056
2057            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
2058            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
2059            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
2060            _test_addmm_addmv(self, torch.addmm, M, m1, m2, transpose_out=t4, layout=layout, mode="dense_result")
2061
2062    @parametrize("k", [0, 1, 8])
2063    @parametrize("n", [0, 1, 10])
2064    @parametrize("m", [0, 1, 25])
2065    @skipCPUIfNoMklSparse
2066    @dtypes(*floating_and_complex_types())
2067    @dtypesIfCUDA(*floating_types_and(torch.complex64,
2068                                      *[torch.bfloat16] if SM80OrLater else [],
2069                                      *[torch.half] if SM53OrLater else [],
2070                                      *[torch.complex128]
2071                                      if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
2072                                      else []))
2073    @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6,
2074                        torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8})
2075    def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k):
2076        if (TEST_WITH_ROCM and k != 0 and n != 0 and m != 0):
2077            self.skipTest("Skipped on ROCm")
2078        M = torch.randn(n, m, device=device).to(dtype)
2079        m1 = torch.randn(n, k, device=device).to(dtype)
2080        m2 = torch.randn(k, m, device=device).to(dtype)
2081        _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=torch.sparse_csr, mode="all_sparse")
2082
2083        M = torch.randn(n, m, device=device).to(dtype).to_sparse_csr()
2084        m1 = torch.randn(n, k + 1, device=device).to(dtype).to_sparse_csr()
2085        m2 = torch.randn(k, m, device=device).to(dtype).to_sparse_csr()
2086        self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2))
2087        self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2))
2088
2089    @skipCPUIfNoMklSparse
2090    @dtypes(torch.float)
2091    def test_addmm_errors(self, device, dtype):
2092        # test that the errors are the same for dense and sparse versions
2093        import re
2094
2095        def test1(*, is_sparse):
2096            # shapes must be compatible for matrix multiplication
2097            a = make_tensor((2, 3), dtype=dtype, device=device)
2098            if is_sparse:
2099                a_sparse = a.to_sparse_csr()
2100                return torch.addmm(a, a_sparse, a)
2101            else:
2102                return torch.addmm(a, a, a)
2103
2104        def test2(*, is_sparse):
2105            # mat2 must be a matrix
2106            a = make_tensor((2, 3), dtype=dtype, device=device)
2107            if is_sparse:
2108                a_sparse = a.to_sparse_csr()
2109                return torch.addmm(a, a_sparse, a.unsqueeze(0))
2110            else:
2111                return torch.addmm(a, a, a.unsqueeze(0))
2112
2113        def test3(*, is_sparse):
2114            # the first input needs to be 1D or 2D
2115            a = make_tensor((3, 3), dtype=dtype, device=device)
2116            if is_sparse:
2117                a_sparse = a.to_sparse_csr()
2118                return torch.addmm(a.unsqueeze(0), a_sparse, a)
2119            else:
2120                return torch.addmm(a.unsqueeze(0), a, a)
2121
2122        for test in (test1, test2, test3):
2123            try:
2124                test(is_sparse=False)
2125            except RuntimeError as msg:
2126                with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
2127                    test(is_sparse=True)
2128
2129    @skipCPUIfNoMklSparse
2130    @dtypes(torch.float)
2131    def test_mm_errors(self, device, dtype):
2132        # test that the errors are the same for dense and sparse versions
2133        import re
2134
2135        def test1(*, is_sparse):
2136            # shapes must be compatible for matrix multiplication
2137            a = make_tensor((2, 3), dtype=dtype, device=device)
2138            if is_sparse:
2139                a_sparse = a.to_sparse_csr()
2140                return torch.mm(a_sparse, a)
2141            else:
2142                return torch.mm(a, a)
2143
2144        def test2(*, is_sparse):
2145            # mat2 must be a matrix
2146            a = make_tensor((2, 3), dtype=dtype, device=device)
2147            if is_sparse:
2148                a_sparse = a.to_sparse_csr()
2149                return torch.mm(a_sparse, a.unsqueeze(0))
2150            else:
2151                return torch.mm(a, a.unsqueeze(0))
2152
2153        for test in (test1, test2):
2154            try:
2155                test(is_sparse=False)
2156            except RuntimeError as msg:
2157                with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))):
2158                    test(is_sparse=True)
2159
2160    @sparse_compressed_nonblock_layouts()
2161    @dtypes(torch.float, torch.double)
2162    def test_add(self, device, layout, dtype):
2163        def _test_spadd_shape(nnz, shape):
2164            # sparse.to_dense() uses torch.add internally, so if torch.add is wrong,
2165            # the dense tensor will be wrong but this test would still pass;
2166            # there's a separate test that checks the correctness of the .to_dense() call.
2167            x = self.genSparseCompressedTensor(shape, nnz,
2168                                               dtype=dtype,
2169                                               device=device,
2170                                               index_dtype=torch.int32,
2171                                               layout=layout,
2172                                               blocksize=())
2173            y = torch.randn(*shape, dtype=dtype, device=device)
2174            r = random.random()
2175
2176            res = torch.add(y, x, alpha=r)
2177            expected = y + r * x.to_dense()
2178            self.assertEqual(res, expected)
2179            res_perm = torch.add(x, y, alpha=r)
2180            self.assertEqual(res_perm, expected)
2181
2182            # Non-contiguous dense tensor
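            # (built by swapping the outer dims and transposing back, so the shape
            #  matches `shape` but the strides are no longer contiguous)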
2183            s = list(shape)
2184            s[0] = shape[-1]
2185            s[-1] = shape[0]
2186            y = torch.randn(*s, dtype=dtype, device=device)
2187            y.transpose_(0, len(s) - 1)
2188            r = random.random()
2189
2190            res = torch.add(y, x, alpha=r)
2191            expected = y + r * x.to_dense()
2192            res_perm = torch.add(x, y, alpha=r)
2193
2194            self.assertEqual(res, expected)
2195            self.assertEqual(res_perm, expected)
2196
2197
2198        ns = [2, 5]
2199        batch_shapes = [(), (2,), (2, 3)]
2200        for b, m, n in itertools.product(batch_shapes, ns, ns):
2201            _test_spadd_shape(0, (*b, m, n))
2202            _test_spadd_shape(m * n // 2, (*b, m, n))
2203            _test_spadd_shape(m * n, (*b, m, n))
2204
2205    @dtypes(torch.float, torch.double)
2206    def test_mul(self, device, dtype):
2207        # TODO: This whole test should be migrated to OpInfos
2208        def _test_spadd_shape(fn, nnz, shape):
2209            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
2210            y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
2211
2212            # Forward comparison
2213            res_sparse_sparse = fn(y, x)
2214            res_dense_sparse = fn(y.to_dense(), x)
2215            res_sparse_dense = fn(y, x.to_dense())
2216            expected = fn(y.to_dense(), x.to_dense())
2217            self.assertEqual(res_sparse_sparse, expected)
2218            # TODO: While the result of mul(dense, csr) is csr, it is not fully compressed.
2219            # That means it may contain materialized zeros, since the dense argument
2220            # is converted according to the sparsity pattern of csr. In the future
2221            # we might require the result to be fully compressed.
2222            self.assertEqual(res_dense_sparse, expected)
2223            self.assertEqual(res_sparse_dense, expected)
2224
2225            # Grad comparison
2226            x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
2227            y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
2228            z = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
2229
2230            # csr * csr -> csr with csr, csr gradients
2231            x_a = x.clone().requires_grad_()
2232            y_a = y.clone().requires_grad_()
2233
2234            fn(y_a, x_a).backward(z)
2235
2236            x_dense_a = x.to_dense().requires_grad_()
2237            y_dense_a = y.to_dense().requires_grad_()
2238
2239            fn(y_dense_a, x_dense_a).backward(z.to_dense())
2240
2241            self.assertEqual(x_a.grad.layout, torch.sparse_csr)
2242            self.assertEqual(y_a.grad.layout, torch.sparse_csr)
2243
2244            self.assertEqual(x_a.grad.to_dense(), x_dense_a.grad)
2245            self.assertEqual(y_a.grad.to_dense(), y_dense_a.grad)
2246
2247            # TODO: Currently strided Tensors cannot have csr gradients
2248            # dense * csr -> csr with csr, dense gradients
2249            x_a = x.clone().requires_grad_()
2250            y_a = y.to_dense().clone().requires_grad_()
2251            err_msg = "Function MulBackward0 returned an invalid gradient at index 0 - expected layout Strided but got SparseCsr"
2252            with self.assertRaisesRegex(RuntimeError, err_msg):
2253                fn(y_a, x_a).backward(z)
2254
2255            # csr * dense -> csr with dense, csr gradients
2256            x_a = x.to_dense().clone().requires_grad_()
2257            y_a = y.clone().requires_grad_()
2258            err_msg = "Function MulBackward0 returned an invalid gradient at index 1 - expected layout Strided but got SparseCsr"
2259            with self.assertRaisesRegex(RuntimeError, err_msg):
2260                fn(y_a, x_a).backward(z)
2261
2262        _test_spadd_shape(torch.mul, 100, [100, 100])
2263        _test_spadd_shape(torch.mul, 0, [100, 100])
2264        _test_spadd_shape(torch.mul, 100, [100, 1])
2265        _test_spadd_shape(torch.mul, 100, [1, 100])
2266
2267    # TODO: enable hybrid once to_dense supports it
2268    @parametrize('enable_hybrid', [False])
2269    @all_sparse_compressed_layouts()
2270    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
2271    def test_mul_scalar(self, layout, device, dtype, enable_hybrid):
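        # Multiplying a sparse compressed tensor by a scalar (given as a 0-dim tensor or a
        # Python number) should match the dense result; the in-place variant is checked
        # only when the result type can be cast back to the input dtype.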
2272        for sparse in self.generate_simple_inputs(
2273                layout, device=device, dtype=dtype, index_dtype=torch.int32, enable_hybrid=enable_hybrid):
2274            for scalar_dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half):
2275                # ComplexHalf is experimental
2276                if dtype is torch.half and scalar_dtype.is_complex:
2277                    continue
2278
2279                scalar_t = torch.tensor(2, dtype=scalar_dtype)
2280                for scalar in (scalar_t, scalar_t.item()):
2281                    res_out = sparse.mul(scalar)
2282                    self.assertEqual(res_out, scalar * sparse)
2283
2284                    res_dense_out = sparse.to_dense().mul(scalar)
2285                    # BUG: dispatcher ignores mul.Scalar(Tensor, Scalar)
2286                    # This issue is circumvented in the mul(Tensor, Tensor) kernel.
2287                    self.assertEqual(res_out, res_dense_out)
2288
2289                    if dtype == torch.result_type(sparse, scalar):
2290                        res_in_dense = sparse.to_dense().mul_(scalar)
2291                        res_in = sparse.clone().mul_(scalar)
2292                        self.assertEqual(res_in, res_in_dense)
2293                        self.assertEqual(res_out, res_in)
2294
2295    @skipCPUIfNoMklSparse
2296    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2297    def test_sparse_add(self, device, dtype):
2298        def run_test(m, n, index_dtype):
2299
2300            alpha = random.random()
2301            nnz1 = random.randint(0, m * n)
2302            nnz2 = random.randint(0, m * n)
2303            nnz3 = random.randint(0, m * n)
2304
2305            if TEST_WITH_ROCM:
2306                # ROCm fails when nnz = 0
2307                nnz1, nnz2, nnz3 = max(1, nnz1), max(1, nnz2), max(1, nnz3)
2308
2309            S1 = self.genSparseCSRTensor([m, n], nnz1, dtype=dtype, device=device, index_dtype=index_dtype)
2310            S2 = self.genSparseCSRTensor([m, n], nnz2, dtype=dtype, device=device, index_dtype=index_dtype)
2311            S3 = self.genSparseCSRTensor([m, n], nnz3, dtype=dtype, device=device, index_dtype=index_dtype)
2312            sparse_args = [S1, S2, S3]
2313            dense_args = [t.to_dense() for t in sparse_args]
2314            arg_idx = list(range(len(sparse_args)))
2315            out_idx = arg_idx + [None]
2316
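            # Exercise every pairing of the three sparse operands as inputs, optionally
            # writing the result into one of them via out=.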
2317            for idx1, idx2, idx3 in itertools.product(arg_idx, arg_idx, out_idx):
2318                s1 = sparse_args[idx1]
2319                s2 = sparse_args[idx2]
2320                s3 = None if idx3 is None else sparse_args[idx3]
2321                d1 = dense_args[idx1]
2322                d2 = dense_args[idx2]
2323                d3 = None if idx3 is None else dense_args[idx3]
2324
2325                expected = torch.add(d1, d2, alpha=alpha, out=d3)
2326                actual = torch.add(s1, s2, alpha=alpha, out=s3)
2327                self.assertEqual(actual.crow_indices().dtype, index_dtype)
2328                self.assertEqual(actual.col_indices().dtype, index_dtype)
2329                self.assertEqual(actual, expected)
2330                self.assertEqual(s3, d3)
2331                if s3 is not None:
2332                    self.assertEqual(s3.crow_indices().dtype, index_dtype)
2333                    self.assertEqual(s3.col_indices().dtype, index_dtype)
2334
2335        for index_dtype in [torch.int32, torch.int64]:
2336            for m, n in itertools.product([3, 5], [3, 5]):
2337                run_test(m, n, index_dtype)
2338
2339    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2340    def test_sparse_add_errors(self, device, dtype):
2341        def run_test(index_dtype):
2342            a = self.genSparseCSRTensor((2, 2), 3, dtype=dtype, device=device, index_dtype=index_dtype)
2343            b = self.genSparseCSRTensor((2, 1), 2, dtype=dtype, device=device, index_dtype=index_dtype)
2344            with self.assertRaisesRegex(RuntimeError, "Expected input tensors to have the same shape"):
2345                torch.add(a, b)
2346
2347        for index_dtype in [torch.int32, torch.int64]:
2348            run_test(index_dtype)
2349
2350    @skipCPUIfNoMklSparse
2351    @skipCUDAIf(
2352        not _check_cusparse_triangular_solve_available(),
2353        "cuSparse Generic API SpSV is not available"
2354    )
2355    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2356    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2357                        torch.float64: 1e-8, torch.complex128: 1e-8})
2358    def test_sparse_triangular_solve(self, device, dtype):
2359
2360        def run_test(n, k, upper, unitriangular, transpose, zero):
2361            if not unitriangular:
2362                triangle_function = torch.triu if upper else torch.tril
2363            else:
2364                # Make sure diagonal elements are not materialized.
2365                # This exercises `unitriangular=True` without relying on the
2366                # explicit presence of the diagonal indices.
2367                if upper:
2368                    def remove_diagonal(t):
2369                        return t.triu(-1)
2370
2371                else:
2372                    def remove_diagonal(t):
2373                        return t.tril(-1)
2374
2375                triangle_function = remove_diagonal
2376
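            # With zero=True, A has no stored elements; the sparse solve is expected to
            # return an all-NaN solution (checked below via A_sparse._nnz() == 0).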
2377            make_A = torch.zeros if zero else make_tensor
2378            A = make_A((n, n), dtype=dtype, device=device)
2379            A = triangle_function(A)
2380            A_sparse = A.to_sparse_csr()
2381            B = make_tensor((n, k), dtype=dtype, device=device)
2382
2383            expected = torch.triangular_solve(B, A, upper=upper, unitriangular=unitriangular, transpose=transpose)
2384            expected_X = expected.solution
2385
2386            actual = torch.triangular_solve(B, A_sparse, upper=upper, unitriangular=unitriangular, transpose=transpose)
2387            actual_X = actual.solution
2388            actual_A_clone = actual.cloned_coefficient
2389            self.assertTrue(actual_A_clone.numel() == 0)
2390            if A_sparse._nnz() == 0:
2391                self.assertTrue(actual_X.isnan().all())
2392                return
2393            self.assertEqual(actual_X, expected_X)
2394
2395            # test out with C contiguous strides
2396            out = torch.empty_strided((n, k), (k, 1), dtype=dtype, device=device)
2397            torch.triangular_solve(
2398                B, A_sparse,
2399                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
2400            )
2401            self.assertEqual(out, expected_X)
2402
2403            # test out with F contiguous strides
2404            out = torch.empty_strided((n, k), (1, n), dtype=dtype, device=device)
2405            torch.triangular_solve(
2406                B, A_sparse,
2407                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
2408            )
2409            self.assertEqual(out, expected_X)
2410            self.assertEqual(out.stride(), (1, n))
2411
2412            # test out with discontiguous strides
2413            out = torch.empty_strided((2 * n, k), (1, 2 * n), dtype=dtype, device=device)[::2]
2414            if n > 0 and k > 0:
2415                self.assertFalse(out.is_contiguous())
2416                self.assertFalse(out.t().is_contiguous())
2417            before_stride = out.stride()
2418            torch.triangular_solve(
2419                B, A_sparse,
2420                upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone)
2421            )
2422            self.assertEqual(out, expected_X)
2423            self.assertEqual(out.stride(), before_stride)
2424
2425        ks = [0, 1, 3]
2426        ns = [5, 3, 0]
2427        for (k, n), (upper, unitriangular, transpose, zero) in itertools.product(itertools.product(ks, ns),
2428                                                                                 itertools.product([True, False], repeat=4)):
2429            run_test(n, k, upper, unitriangular, transpose, zero)
2430
2431    @skipCUDAIf(
2432        not _check_cusparse_sddmm_available(),
2433        "cuSparse Generic API SDDMM is not available"
2434    )
2435    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2436    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2437                        torch.float64: 1e-8, torch.complex128: 1e-8})
2438    def test_sampled_addmm(self, device, dtype):
2439        def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None):
2440            if dtype.is_complex:
2441                alpha = random.random() + 0.3j if alpha is None else alpha
2442                beta = random.random() + 0.6j if beta is None else beta
2443            else:
2444                alpha = random.random() if alpha is None else alpha
2445                beta = random.random() if beta is None else beta
2446
2447            if op_a and a.shape == b.shape:
2448                a = a.mH
2449            if op_b and a.shape == b.shape:
2450                b = b.mH
2451
2452            actual = torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta)
2453
2454            out = torch.sparse_csr_tensor(
2455                *map(torch.clone, (actual.crow_indices(), actual.col_indices())),
2456                torch.empty_like(actual.values()),
2457                size=actual.shape
2458            )
2459            torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta, out=out)
2460
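            # Dense reference: sampled_addmm computes alpha * (a @ b) restricted to the
            # sparsity pattern of c (spy_c has ones at c's specified elements) plus beta * c.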
2461            spy_c = torch.sparse_csr_tensor(c.crow_indices(), c.col_indices(), torch.ones_like(c.values()), size=c.shape)
2462            expected = alpha * (a @ b) * spy_c.to_dense() + beta * c.to_dense()
2463            self.assertEqual(actual.to_dense(), out.to_dense())
2464            self.assertEqual(actual.to_dense(), expected)
2465
2466        mnk = list(itertools.product([2, 5], repeat=3))
2467
2468        # Add a test case for size 0 a and b tensors
2469        mnk = mnk + [(5, 5, 0)]
2470
2471        batch_shapes = [(), (2,), (2, 3)]
2472        tf = [True, False]
2473        for index_dtype in [torch.int32, torch.int64]:
2474            for (m, n, k), b, noncontiguous, bcast_c in itertools.product(mnk, batch_shapes, tf, tf):
2475                if bcast_c and len(b) == 0:
2476                    continue
2477                nnz = random.randint(0, m * n)
2478                c_batch = () if bcast_c else b
2479                c = self.genSparseCSRTensor((*c_batch, m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
2480                a = make_tensor((*b, m, k), dtype=dtype, device=device, noncontiguous=noncontiguous)
2481                b = make_tensor((*b, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous)
2482                for op_a, op_b in itertools.product([True, False], repeat=2):
2483                    run_test(c, a, b, op_a, op_b)
2484
2485    @skipCUDAIf(
2486        not _check_cusparse_sddmm_available(),
2487        "cuSparse Generic API SDDMM is not available"
2488    )
2489    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2490    def test_sampled_addmm_autograd(self, device, dtype):
2491        from torch.testing._internal.common_methods_invocations import sample_inputs_sparse_sampled_addmm
2492
2493        samples = list(sample_inputs_sparse_sampled_addmm(None, device, dtype, requires_grad=True))
2494
2495        for sample, dense_covector in zip(samples, [True, False]):
2496            c = sample.input
2497            a = sample.args[0]
2498            b = sample.args[1]
2499
2500            # Compute sparse result
2501            output = torch.sparse.sampled_addmm(c, a, b, **sample.kwargs)
2502            covector = torch.randn_like(output).to_dense() if dense_covector else torch.randn_like(output)
2503            output.backward(covector)
2504
2505            # Compute dense result and compare with sparse result
2506            c1, a1, b1 = (x.detach().to_dense().requires_grad_(True) for x in [c, a, b])
2507            dense_output = sample.kwargs['alpha'] * (a1 @ b1) * torch.ones_like(c).to_dense() + sample.kwargs['beta'] * c1
2508            self.assertEqual(output, dense_output)
2509            dense_covector = covector.to_dense()
2510            dense_output.backward(dense_covector)
2511            self.assertEqual(c.grad, c1.grad)
2512            self.assertEqual(a.grad, a1.grad)
2513            self.assertEqual(b.grad, b1.grad)
2514
2515    @onlyCUDA
2516    # It works on ROCm; the CUDA issue is still active
2517    @skipCUDAIf(not TEST_WITH_ROCM, "Causes CUDA memory exception, see https://github.com/pytorch/pytorch/issues/72177")
2518    @skipCUDAIf(
2519        not _check_cusparse_sddmm_available(),
2520        "cuSparse Generic API SDDMM is not available"
2521    )
2522    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2523    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
2524                        torch.float64: 1e-8, torch.complex128: 1e-8})
2525    def test_sampled_addmm_zero_sized(self, device, dtype):
2526        def run_test(c, a, b):
2527            actual = torch.sparse.sampled_addmm(c, a, b)
2528            self.assertEqual(actual.shape, c.shape)
2529
2530        for m, n, k in itertools.product([0, 5], repeat=3):
2531            c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr)
2532            a = make_tensor((m, k), dtype=dtype, device=device)
2533            b = make_tensor((k, n), dtype=dtype, device=device)
2534            run_test(c, a, b)
2535
2536    @onlyCUDA
2537    @skipCUDAIf(
2538        not _check_cusparse_sddmm_available(),
2539        "cuSparse Generic API SDDMM is not available"
2540    )
2541    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
2542    def test_sampled_addmm_errors(self, device, dtype):
2543        # test that the errors are the same for dense and sparse sampled versions
2545
2546        # shapes must be compatible for matrix multiplication
2547        a = make_tensor((2, 3), dtype=dtype, device=device)
2548        a_sparse = a.to_sparse_csr()
2549        with self.assertRaisesRegex(RuntimeError, r"cannot be multiplied"):
2550            torch.sparse.sampled_addmm(a_sparse, a, a)
2551
2552        # mat1 must be a matrix
2553        with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to be a matrix"):
2554            torch.sparse.sampled_addmm(a_sparse, a[..., 0, :], a)
2555
2556        # mat2 must be a matrix
2557        with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to be a matrix"):
2558            torch.sparse.sampled_addmm(a_sparse, a, a[..., 0, :])
2559
2560        a = make_tensor((2, 2), dtype=dtype, device=device)
2561        b = make_tensor((3, 3), dtype=dtype, device=device)
2562        b_sparse = b.to_sparse_csr()
2563        with self.assertRaisesRegex(RuntimeError, r"self.shape\[-2\] must match mat1.shape\[-2\]"):
2564            torch.sparse.sampled_addmm(b_sparse, a, a)
2565
2566        b = make_tensor((2, 3), dtype=dtype, device=device)
2567        b_sparse = b.to_sparse_csr()
2568        with self.assertRaisesRegex(RuntimeError, r"self.shape\[-1\] must match mat2.shape\[-1\]"):
2569            torch.sparse.sampled_addmm(b_sparse, a, a)
2570
2571        a = make_tensor((2, 2), dtype=dtype, device=device)
2572        a_sparse = a.to_sparse_csr()
2573        with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to have strided layout"):
2574            torch.sparse.sampled_addmm(a_sparse, a_sparse, a_sparse)
2575
2576        with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to have strided layout"):
2577            torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
2578
2579    @onlyCPU
2580    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
2581    @precisionOverride({torch.bfloat16: 0.01})
2582    def test_sparse_mm_reduce_sum(self, device, dtype):
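        # torch.sparse.mm(sparse, mat, 'sum') should match dense torch.mm in the forward
        # result and, when train=True, in the gradients of both inputs.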
2583        def run_test(m, n, k, nnz, train):
2584            sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)
2585            dense = sparse.to_dense()
2586
2587            mat = torch.randn(k, n, dtype=dtype)
2588            ref_mat = mat.clone()
2589
2590            if train:
2591                sparse.requires_grad_()
2592                mat.requires_grad_()
2593                dense.requires_grad_()
2594                ref_mat.requires_grad_()
2595
2596            ref_out = torch.mm(dense, ref_mat)
2597            out = torch.sparse.mm(sparse, mat, 'sum')
2598
2599            self.assertEqual(out, ref_out)
2600
2601            if train:
2602                ref_out.sum().backward()
2603                out.sum().backward()
2604
2605                grad_input = sparse.grad
2606                ref_grad_input = dense.grad
2607                grad_mat = mat.grad
2608                ref_grad_mat = ref_mat.grad
2609
2610                self.assertEqual(grad_input.to_dense(), ref_grad_input)
2611                self.assertEqual(grad_mat, ref_grad_mat)
2612
2613        run_test(4, 5, 4, 10, False)
2614        run_test(4, 4, 4, 16, True)
2615
2616    @skipIfTorchDynamo()
2617    @onlyCPU
2618    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
2619    @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
2620    def test_sparse_mm_reduce(self, device, dtype):
2621        def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
2622            csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
2623            mat = torch.randn(n, k, dtype=dtype)
2624            ref_mat = mat.clone()
2625            ref_values = csr.values().clone()
2626
2627            out_int32 = index_dtype == torch.int32
2628            coo_indices = torch._convert_indices_from_csr_to_coo(
2629                csr.crow_indices(),
2630                csr.col_indices(),
2631                out_int32=out_int32)
2632            row, col = coo_indices[0], coo_indices[1]
2633
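            # Dense reference: gather rows of `mat` by column index, scale them by the CSR
            # values, and scatter-reduce into the output rows with the requested reduction.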
2634            def ref(row, col, val, mat):
2635                out = torch.zeros([m, k], dtype=dtype)
2636                weight = mat.index_select(0, col)
2637                src = weight.mul(val.view(-1, 1))
2638                index = row.view(-1, 1).expand_as(weight)
2639                index = index.to(dtype=torch.int64)
2640                # scatter_reduce expects index to be int64
2641                out.scatter_reduce_(0, index, src, reduce=reduce_type, include_self=False)
2642                return out
2643
2644            if train:
2645                csr.requires_grad_()
2646                mat.requires_grad_()
2647                ref_values.requires_grad_()
2648                ref_mat.requires_grad_()
2649
2650            ref_out = ref(row, col, ref_values, ref_mat)
2651            out = torch.sparse.mm(csr, mat, reduce_type)
2652            self.assertEqual(out, ref_out)
2653
2654            if train and dtype not in (torch.bfloat16, torch.float16):
2655                ref_out.sum().backward()
2656                out.sum().backward()
2657
2658                grad_values = csr.grad.values()
2659                grad_weight = mat.grad
2660                ref_grad_values = ref_values.grad
2661                ref_grad_weight = ref_mat.grad
2662                self.assertEqual(grad_values, ref_grad_values)
2663                self.assertEqual(grad_weight, ref_grad_weight)
2664
2665        for train in [False, True]:
2666            for index_dtype in [torch.int32, torch.int64]:
2667                for reduce_type in ["sum", "mean", "amax", "amin"]:
2668                    # by setting nnz < M, we create empty rows
2669                    run_test(3, 4, 11, 1, reduce_type, index_dtype, train)
2670                    run_test(3, 4, 11, 6, reduce_type, index_dtype, train)
2671                    run_test(3, 4, 11, 12, reduce_type, index_dtype, train)
2672                    # we are doing blocking with 4x vector length in the kernel,
2673                    # so we need to test when K > 4x vector length
2674                    run_test(4, 7, 33, 13, reduce_type, index_dtype, train)
2675
2676    @skipMeta
2677    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
2678    def test_coo_csr_conversion(self, device, dtype):
2679        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
2680            size = (m, n)
2681            dense = make_tensor(size, dtype=dtype, device=device)
2682            coo_sparse = dense.to_sparse()
2683            csr_sparse = coo_sparse.to_sparse_csr()
2684
2685            self.assertEqual(csr_sparse.to_dense(), dense)
2686
2687    @skipMeta
2688    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
2689    def test_csr_coo_conversion(self, device, dtype):
2690        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
2691            size = (m, n)
2692            dense = make_tensor(size, dtype=dtype, device=device)
2693            csr_sparse = dense.to_sparse_csr()
2694            coo_sparse = csr_sparse.to_sparse()
2695
2696            self.assertEqual(coo_sparse.to_dense(), dense)
2697
2698    # Currently, there is no rule in PyTorch for filling zeros in the outputs
2699    #   from operations on Sparse CSR tensors. Hence only those operators are supported
2700    #   which have a 0->0 correspondence, for example: sin(0) = 0 and tan(0) = 0, but
2701    #   cos(0) = 1 (and hence cos is not supported).
2702    # Note: here, we do this test only for unary operators
2703    @ops(sparse_csr_unary_ufuncs)
2704    def test_zero_to_zero_correspondence_unary(self, device, dtype, op):
2705        zero = torch.zeros((1, 2), dtype=dtype, device=device)
2706        tensor_explicit_zeros = torch.sparse_csr_tensor([0, 1], [1], [0], dtype=dtype, device=device)
2707
2708        output_zero = op(zero)
2709        expected_zero = zero.to(output_zero.dtype)
2710
2711        output_explicit_zeros = op(tensor_explicit_zeros).to_dense()
2712        expected_explicit_zeros = tensor_explicit_zeros.to_dense().to(output_explicit_zeros.dtype)
2713
2714        for (output, expected) in [
2715                (output_zero, expected_zero),
2716                (output_explicit_zeros, expected_explicit_zeros)
2717        ]:
2718            self.assertEqual(output, expected, f"This operator ({op.name}) should not be supported for "
2719                             "Sparse CSR as it breaks 0->0 correspondence.")
2720
2721        for inp in [zero.to_sparse_csr(), tensor_explicit_zeros]:
2722            self.assertEqual(op(inp).values().numel(), inp.values().numel(),
2723                             f"{op.name} fails to preserve sparsity pattern.")
2724
2725    @ops(sparse_csr_unary_ufuncs)
2726    def test_sparse_csr_unary_out(self, device, dtype, op):
2727        samples = op.sample_inputs(device, dtype)
2728
2729        if not op.supports_out:
2730            self.skipTest("Skipped! Out not supported")
2731
2732        for sample in samples:
2733            assert torch.is_tensor(sample.input)
2734            # Sparse CSR only supports 2D tensors as inputs
2735            # Fail early to prevent silent success with this test
2736            if sample.input.ndim != 2:
2737                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")
2738
2739            sample.input = sample.input.to_sparse_csr()
2740            expect = op(sample.input, *sample.args, **sample.kwargs)
2741
2742            out = self.genSparseCSRTensor(sample.input.size(), sample.input._nnz(),
2743                                          device=sample.input.device, dtype=expect.dtype,
2744                                          index_dtype=sample.input.crow_indices().dtype)
2745            op(sample.input, *sample.args, **sample.kwargs, out=out)
2746
2747            self.assertEqual(out, expect)
2748
2749    @ops(sparse_csr_unary_ufuncs)
2750    def test_sparse_csr_unary_inplace(self, device, dtype, op):
2751        samples = op.sample_inputs(device, dtype)
2752
2753        if op.inplace_variant is None:
2754            self.skipTest("Skipped! Inplace variant not supported!")
2755
2756        for sample in samples:
2757            assert torch.is_tensor(sample.input)
2758            # Sparse CSR only supports 2D tensors as inputs
2759            # Fail early to prevent silent success with this test
2760            if sample.input.ndim != 2:
2761                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")
2762
2763            sample.input = sample.input.to_sparse_csr()
2764            expect = op(sample.input, *sample.args, **sample.kwargs)
2765
2766            if not torch.can_cast(expect.dtype, dtype):
2767                with self.assertRaisesRegex(RuntimeError, "result type"):
2768                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
2769                continue
2770
2771            if sample.input.is_complex() and op.name == "abs":
2772                with self.assertRaisesRegex(RuntimeError, "not supported"):
2773                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
2774                continue
2775
2776            actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
2777
2778            self.assertIs(actual, sample.input)
2779            self.assertEqual(actual, expect)
2780
2781    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
2782    @ops(sparse_csr_unary_ufuncs, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble])
2783    def test_autograd_sparse_csr_unary(self, device, dtype, op):
2784        if op.name not in UNARY_EWISE_CSR_ALLOW_AUTOGRAD:
2785            self.skipTest(f"Skipped! Unary op {op.name} not supported with CSR input and autograd")
2786
2787        samples = list(op.sample_inputs(device, dtype))
2788
2789        # Fail early to prevent silent success with this test
2790        ndims_equals_2d = (s.input.ndim == 2 for s in samples)
2791        if not any(ndims_equals_2d):
2792            raise ValueError("Expected at least one 2D tensor in samples.")
2793
2794        for sample in samples:
2795            # We must skip samples of low dimensionality; we can't convert them to sparse compressed layouts
2796            if sample.input.ndim < 2:
2797                continue
2798            sparse_input = sample.input.to_sparse_csr().requires_grad_(True)
2799
2800            def fn(input):
2801                output = op.gradcheck_wrapper(op.get_op(), input, *sample.args, **sample.kwargs)
2802                if sample.output_process_fn_grad is not None:
2803                    return sample.output_process_fn_grad(output)
2804                return output
2805
2806            # Compute sparse result
2807            output = fn(sparse_input)
2808            covector = torch.randn_like(output)
2809            output.backward(covector)
2810            self.assertTrue(torch.is_tensor(sparse_input.grad))
2811            self.assertTrue(sparse_input.grad.is_sparse_csr)
2812
2813            # Compute dense result and compare with sparse result
2814            dense_input = sparse_input.detach().to_dense().requires_grad_(True)
2815            dense_output = fn(dense_input)
2816            dense_covector = covector.to_dense()
2817            dense_output.backward(dense_covector)
2818            self.assertEqual(sparse_input.grad, dense_input.grad)
2819
2820    @skipCUDAIf(
2821        not _check_cusparse_sddmm_available(),
2822        "cuSparse Generic API SDDMM is not available"
2823    )
2824    @dtypes(torch.float64)
2825    def test_autograd_dense_output_addmm(self, device, dtype):
2826        from torch.testing._internal.common_methods_invocations import sample_inputs_addmm
2827
2828        samples = list(sample_inputs_addmm(None, device, dtype, requires_grad=True))
2829
2830        # Fail early to prevent silent success with this test
2831        ndims_equals_2d = (s.args[0].ndim == 2 for s in samples)
2832        if not any(ndims_equals_2d):
2833            raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.")
2834
2835        for sample in samples:
2836            a = sample.args[0].relu().to_sparse_csr()
2837            if sample.args[0].shape == sample.args[1].shape:
2838                import warnings
2839                warnings.warn("Broken for square matrices, see https://github.com/pytorch/pytorch/issues/116565")
2840                continue
2841
2842            # Here we test the autograd path wrt the dense inputs
2843            for addmm in [torch.addmm, torch.sparse.addmm]:
2844
2845                def fn(c, b):
2846                    output = addmm(c, a, b, **sample.kwargs)
2847                    if sample.output_process_fn_grad is not None:
2848                        return sample.output_process_fn_grad(output)
2849                    return output
2850
2851                self.assertTrue(torch.autograd.gradcheck(fn, [sample.input, sample.args[1]], fast_mode=True))
2852
2853                # noncontiguous
2854                c = make_tensor(sample.input.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
2855                b = make_tensor(sample.args[1].shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
2856                self.assertTrue(torch.autograd.gradcheck(fn, [c, b], fast_mode=True))
2857
2858                # Now test the autograd path wrt sparse inputs
2859                for reverse in [True, False]:
2860                    c, b = sample.input, sample.args[1]
2861                    if reverse and a.shape != b.shape:
2862                        continue
2863
2864                    def fn(a):
2865                        inputs = (c, b, a) if reverse else (c, a, b)
2866                        output = addmm(*inputs, **sample.kwargs)
2867                        if sample.output_process_fn_grad is not None:
2868                            return sample.output_process_fn_grad(output)
2869                        return output
2870
2871                    # gradcheck doesn't work for sparse CSR yet, so compare against the dense path
2872                    # Compute sparse result
2873                    a = a.detach().requires_grad_(True)
2874                    output = fn(a)
2875                    covector = torch.randn_like(output)
2876                    output.backward(covector)
2877                    self.assertTrue(torch.is_tensor(a.grad))
2878                    if addmm == torch.sparse.addmm:
2879                        self.assertTrue(a.grad.is_sparse_csr)
2880                    else:
2881                        self.assertTrue(a.grad.layout == torch.strided)
2882
2883                    # Compute dense result and compare with sparse result
2884                    dense_a = a.detach().to_dense().requires_grad_(True)
2885                    dense_output = fn(dense_a)
2886                    self.assertEqual(output, dense_output)
2887                    dense_covector = covector.to_dense()
2888                    dense_output.backward(dense_covector)
2889
2890                    if addmm == torch.sparse.addmm:
2891                        self.assertEqual(a.grad, dense_a.grad.sparse_mask(a))
2892                    else:
2893                        self.assertEqual(a.grad, dense_a.grad)
2894
2895    @skipCPUIfNoMklSparse
2896    @dtypes(torch.float64)
2897    def test_autograd_dense_output_addmv(self, device, dtype):
2898        from torch.testing._internal.common_methods_invocations import sample_inputs_addmv
2899
2900        samples = list(sample_inputs_addmv(None, device, dtype, requires_grad=True))
2901
2902        # Fail early to prevent silent success with this test
2903        ndims_equals_2d = (s.args[0].ndim == 2 for s in samples)
2904        if not any(ndims_equals_2d):
2905            raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.")
2906
2907        for sample in samples:
2908            # TODO: Remove detach once we have autograd support for CSR input
2909            a = sample.args[0].to_sparse_csr().detach()
2910
2911            def fn(c, b):
2912                output = torch.addmv(c, a, b, **sample.kwargs)
2913                if sample.output_process_fn_grad is not None:
2914                    return sample.output_process_fn_grad(output)
2915                return output
2916
2917            self.assertTrue(torch.autograd.gradcheck(fn, [sample.input, sample.args[1]], fast_mode=True))
2918
2919            # noncontiguous
2920            c = make_tensor(sample.input.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
2921            b = make_tensor(sample.args[1].shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True)
2922            self.assertTrue(torch.autograd.gradcheck(fn, [c, b], fast_mode=True))
2923
2924    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
2925    @ops(binary_ops_with_dense_output, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, ])
2926    def test_autograd_dense_output(self, device, dtype, op):
2927        if op.name == "mv" and no_mkl_sparse and self.device_type == 'cpu':
2928            self.skipTest("MKL Sparse is not available")
2929
2930        samples = list(op.sample_inputs(device, dtype, requires_grad=True))
2931
2932        # Fail early to prevent silent success with this test
2933        ndims_equals_2d = (s.input.ndim == 2 for s in samples)
2934        if not any(ndims_equals_2d):
2935            raise ValueError("Expected at least one 2D tensor in samples.")
2936
2937        # Here we assume that the signature is op(sparse_input, dense_input) -> dense_output
2938        for sample in samples:
2939            # TODO: Remove detach once we have autograd support for CSR input
2940            sparse_input = sample.input.to_sparse_csr().detach()
2941
2942            def fn(*args):
2943                output = op.gradcheck_wrapper(op.get_op(), sparse_input, *args, **sample.kwargs)
2944                if sample.output_process_fn_grad is not None:
2945                    return sample.output_process_fn_grad(output)
2946                return output
2947
2948            self.assertTrue(torch.autograd.gradcheck(fn, sample.args, fast_mode=True))
2949
2950            # noncontiguous
2951            args = [make_tensor(a.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True) for a in sample.args]
2952            self.assertTrue(torch.autograd.gradcheck(fn, args, fast_mode=True))
2953
2954    @dtypes(*all_types_and_complex())
2955    def test_direct_coo_csr_conversion(self, device, dtype):
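        # Round-tripping COO -> CSR -> COO should reproduce the original COO tensor.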
2956        for m, n in itertools.product([5, 2, 0], [5, 2, 0]):
2957            size = (m, n)
2958            dense = make_tensor(size, dtype=dtype, device=device)
2959            coo_sparse = dense.to_sparse_coo()
2960
2961            self.assertEqual(coo_sparse.to_sparse_csr().to_sparse_coo(), coo_sparse)
2962
2963    @skipMeta
2964    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
2965    def test_sum(self, device, dtype):
2966        def run_test(shape, nnz, index_dtype):
2967            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
2968            self.assertEqual(a.sum(), a.values().sum())
2969            if dtype in floating_types():
2970                a.requires_grad_(True)
2971                a.sum().backward()
2972                self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device))
2973        for shape, index_dtype in itertools.product(
2974                [(10, 5), (10, 10)],
2975                [torch.int32, torch.int64]):
2976            run_test(shape, 0, index_dtype)
2977            run_test(shape, max(shape), index_dtype)
2978            run_test(shape, shape[0] * shape[1], index_dtype)
2979
2980    @skipIfTorchDynamo()
2981    @skipMeta
2982    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
2983    @all_sparse_compressed_layouts()
2984    def test_transpose(self, device, dtype, layout):
2985
2986        def _check_transpose_view(subject, transpose):
2987            self.assertTrue(transpose.values()._is_view())
2988            self.assertTrue(transpose._is_view())
2989            self.assertTrue(transpose._base is subject)
2990
2991        def _check_layout_invariants(transpose):
2992            self.assertEqual(transpose.device, torch.device(device))
2993            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[transpose.layout]
2994            compressed_indices, plain_indices = compressed_indices_mth(transpose), plain_indices_mth(transpose)
2995            torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, transpose.values(),
2996                                                          transpose.shape, transpose.layout)
2997
2998        def check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout):
2999            transpose = subject.transpose(dim0, dim1)
3000            # correct layout
3001            self.assertEqual(transpose.layout, expected_layout)
3002            # transpose must return a view
3003            _check_transpose_view(subject, transpose)
3004            # result uses unsafe construction, so we check invariants
3005            _check_layout_invariants(transpose)
3006            self.assertEqual(transpose.to_dense(), subject_dense.transpose(dim0, dim1))
3007
3008            round_trip = transpose.transpose(dim0, dim1)
3009            self.assertEqual(round_trip.layout, subject.layout)
3010            # transpose must return a view
3011            _check_transpose_view(subject, round_trip)
3012            # result uses unsafe construction, so we check invariants
3013            _check_layout_invariants(round_trip)
3014            self.assertEqual(round_trip.to_dense(), subject_dense)
3015
3016        def check_same_dim_transpose(subject, subject_dense, dim):
3017            transpose = subject.transpose(dim, dim)
3018            # correct layout
3019            self.assertEqual(transpose.layout, subject.layout)
3020            # transpose must return a view
3021            _check_transpose_view(subject, transpose)
3022            # result uses unsafe construction, so we check invariants
3023            _check_layout_invariants(transpose)
3024            self.assertEqual(transpose.to_dense(), subject_dense)
3025
3026        def check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1):
3027            mismatch_name = f"{dim0}\\({name0}\\) and {dim1}\\({name1}\\)"
3028            err = r"transpose\(\): can only transpose dimensions of the same type \(Batch, Sparse, Dense\), got " + mismatch_name
3029
3030            with self.assertRaisesRegex(RuntimeError, err):
3031                subject.transpose(dim0, dim1)
3032
3033        def run_test(shape, nnz, index_type, n_dense, blocksize=()):
3034            subject = self.genSparseCompressedTensor(shape,
3035                                                     nnz,
3036                                                     layout=layout,
3037                                                     device=device,
3038                                                     index_dtype=index_type,
3039                                                     blocksize=blocksize,
3040                                                     dense_dims=n_dense,
3041                                                     dtype=dtype)
3042
3043
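            # Pick a representative dimension index from each group: any batch dims come
            # first, then the two sparse dims, then the dense dims; an entry is None when
            # the group does not have that many dimensions.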
3044            sparse0 = len(shape) - n_dense - 1
3045            sparse1 = sparse0 - 1
3046
3047            dense0 = sparse0 + 1 if n_dense > 0 else None
3048            dense1 = dense0 + 1 if n_dense > 1 else None
3049
3050            n_batch = len(shape) - n_dense - 2
3051            batch0 = sparse1 - 1 if n_batch > 0 else None
3052            batch1 = 0 if n_batch > 1 else None
3053
3054            sparse_dims = (sparse0, sparse1)
3055            dense_dims = (dense0, dense1)
3056            batch_dims = (batch0, batch1)
3057
3058            named0 = [(name, d[0]) for name, d in zip(["Batch", "Sparse", "Dense"], (batch_dims, sparse_dims, dense_dims))]
3059            named1 = [(name, d[1]) for name, d in zip(["Batch", "Sparse", "Dense"], (batch_dims, sparse_dims, dense_dims))]
3060
3061            flipped_layout = {
3062                torch.sparse_csr: torch.sparse_csc,
3063                torch.sparse_csc: torch.sparse_csr,
3064                torch.sparse_bsr: torch.sparse_bsc,
3065                torch.sparse_bsc: torch.sparse_bsr
3066            }[layout]
3067            if n_dense > 0:
3068                # expect all transposes to throw
3069                for (name0, dim0), (name1, dim1) in itertools.product(named0, named1):
3070                    msg = r"transpose\(\): hybrid sparse compressed tensors with dense dimensions are not supported"
3071                    if (dim0 is not None) and (dim1 is not None):
3072                        with self.assertRaisesRegex(RuntimeError, msg):
3073                            subject.transpose(dim0, dim1)
3074            else:
3075                subject_dense = subject.to_dense()
3076                for (name0, dim0), (name1, dim1) in itertools.product(named0, named1):
3077                    if dim0 is not None:
3078                        check_same_dim_transpose(subject, subject_dense, dim0)
3079
3080                        if dim1 is not None:
3081                            if name0 == name1:
3082                                expected_layout = flipped_layout if name0 == "Sparse" else layout
3083                                check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout)
3084                            else:
3085                                check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1)
3086
3087        # batch/sparse, sparse/dense only and full hybrid cases
3088        shape_ndense = list(itertools.product([(2, 4, 6, 2), (10, 6, 4, 2), (2, 4, 4, 2, 6)], [0, 1, 2]))
3089        # sparse only cases
3090        shape_ndense += [[(4, 8), 0], [(2, 2), 0], [(8, 4), 0]]
3091        for (shape, n_dense), index_dtype in itertools.product(shape_ndense, [torch.int32, torch.int64]):
3092            n_batch = len(shape) - n_dense - 2
3093            sparse_shape = shape[n_batch: n_batch + 2]
3094            if layout in (torch.sparse_bsr, torch.sparse_bsc):
3095                # for block layouts, all combinations of 2 and 1 should be valid blocksizes
3096                run_test(shape, 0, index_dtype, n_dense, blocksize=(2, 2))
3097                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 2))
3098                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense, blocksize=(2, 2))
3099                # repeat the realistic sparsity case with varied block sizes
3100                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 1))
3101                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 2))
3102                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 1))
3103            else:
3104                run_test(shape, 0, index_dtype, n_dense)
3105                run_test(shape, max(sparse_shape), index_dtype, n_dense)
3106                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense)
3107
3108    # TODO: This is a stopgap for a rigorous extension of our autograd tests
3109    # to test the functionality of detach
3110    @skipMeta
3111    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
3112    def test_exercise_detach(self, device, dtype):
3113        shape = (3, 3)
3114        nnz = 4
3115        for index_dtype in [torch.int32, torch.int64]:
3116            inp = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
3117            detached_inp = inp.detach()
3118            self.assertEqual(inp, detached_inp)
3119
3120    def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)):
3121        if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]:
3122            tensor = tensor.to_dense()
3123        else:
3124            raise NotImplementedError(repr(tensor))
3125        if layout is torch.sparse_csr:
3126            return sp.csr_matrix(tensor.cpu().numpy())
3127        if layout is torch.sparse_csc:
3128            return sp.csc_matrix(tensor.cpu().numpy())
3129        if layout is torch.sparse_bsr:
3130            return sp.bsr_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
3131        if layout is torch.sparse_bsc:
3132            # SciPy doesn't have native BSC support, but our tests don't need the full
3133            # functionality, so we fake it by using a transposed BSR matrix.
3134            class FakeBscMatrix:
3135                def __init__(self, matrix):
3136                    self._matrix = matrix
3137                    self.shape = tuple(reversed(matrix.shape))
3138                    self.indptr = matrix.indptr
3139                    self.indices = matrix.indices
3140                    self.data = [x.transpose() for x in matrix.data]
3141
3142                @staticmethod
3143                def from_matrix(matrix, blocksize):
3144                    blocksize = tuple(reversed(blocksize))
3145                    matrix = matrix.transpose()
3146                    return FakeBscMatrix(sp.bsr_matrix(matrix, blocksize=blocksize))
3147
3148                def sorted_indices(self):
3149                    sub = self._matrix.sorted_indices()
3150                    return FakeBscMatrix(sub)
3151
3152            return FakeBscMatrix.from_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
3153        raise NotImplementedError(repr(tensor))
3154
3155    @skipMeta
3156    @all_sparse_compressed_layouts('to_layout')
3157    @all_sparse_compressed_layouts('from_layout')
3158    def test_compressed_layout_conversions_coverage(self, device, from_layout, to_layout):
3159        """This test performs a smoke test for covered conversions and verifies
3160        that an exception is thrown for unsupported conversions.
3161
3162        TODO: This test covers a subset of
3163        TestSparseAny.test_to_sparse tests and can be
3164        eliminated. Keeping the test until the new
3165        `Tensor.to_sparse(*, layout, blocksize)` has landed.
3166        """
3167
3168        allowed_pairwise_layouts_sets = {
3169            frozenset({torch.sparse_csc}),
3170            frozenset({torch.sparse_csr}),
3171            frozenset({torch.sparse_csc, torch.sparse_csr}),
3172            frozenset({torch.sparse_csc, torch.sparse_bsc}),
3173            frozenset({torch.sparse_csc, torch.sparse_bsr}),
3174            frozenset({torch.sparse_csr, torch.sparse_bsc}),
3175            frozenset({torch.sparse_csr, torch.sparse_bsr}),
3176            frozenset({torch.sparse_bsc}),
3177            frozenset({torch.sparse_bsr}),
3178            frozenset({torch.sparse_bsc, torch.sparse_bsr}),
3179        }
3180        block_layouts = (torch.sparse_bsr, torch.sparse_bsc)
3181
3182        def _to_from_layout(layout_a, layout_b, a):
3183            expect_error = True
3184            if {layout_a, layout_b} in allowed_pairwise_layouts_sets:
3185                expect_error = False
3186
3187            # BSR -> CSR is not yet supported
3188            if (layout_a, layout_b) == (torch.sparse_bsr, torch.sparse_csr):
3189                expect_error = True
3190            # BSR -> CSC is not yet supported
3191            if (layout_a, layout_b) == (torch.sparse_bsr, torch.sparse_csc):
3192                expect_error = True
3193            # BSC -> CSR is not yet supported
3194            if (layout_a, layout_b) == (torch.sparse_bsc, torch.sparse_csr):
3195                expect_error = True
3196            # BSC -> CSC is not yet supported
3197            if (layout_a, layout_b) == (torch.sparse_bsc, torch.sparse_csc):
3198                expect_error = True
3199            # CSR -> BSR only works for non-batched inputs
3200            if (layout_a, layout_b) == (torch.sparse_csr, torch.sparse_bsr):
3201                if a.dim() > 2:
3202                    expect_error = True
3203            # CSR -> BSC only works for non-batched inputs
3204            if (layout_a, layout_b) == (torch.sparse_csr, torch.sparse_bsc):
3205                if a.dim() > 2:
3206                    expect_error = True
3207            # CSC -> BSR only works for non-batched inputs
3208            if (layout_a, layout_b) == (torch.sparse_csc, torch.sparse_bsr):
3209                if a.dim() > 2:
3210                    expect_error = True
3211            # CSC -> BSC only works for non-batched inputs
3212            if (layout_a, layout_b) == (torch.sparse_csc, torch.sparse_bsc):
3213                if a.dim() > 2:
3214                    expect_error = True
3215
3216            blocksize_a = (1, 1) if layout_a in {torch.sparse_bsr, torch.sparse_bsc} else None
3217            blocksize_b = (1, 1) if layout_b in {torch.sparse_bsr, torch.sparse_bsc} else None
3218            b = a.to_sparse(layout=layout_a, blocksize=blocksize_a)
3219            if expect_error:
3220                with self.assertRaises(RuntimeError):
3221                    b.to_sparse(layout=layout_b, blocksize=blocksize_b)
3222            else:
3223                c = b.to_sparse(layout=layout_b, blocksize=blocksize_b)
3224                self.assertEqual(a.to_dense(), c.to_dense())
3225
3226                # change of blocksize upon conversion is not yet supported.
3227                if b.layout in block_layouts:
3228                    for block_layout in block_layouts:
3229                        with self.assertRaisesRegex(RuntimeError,
3230                                                    "conversion from.*to.*with blocksize changed from.*to.*is not supported"):
3231                            b.to_sparse(layout=block_layout, blocksize=(3, 3))
3232
3233        batch_dims = [(), (2,), (2, 2), (2, 2, 2)]
3234        sparse_dims = (6, 12)
3235        for batch_dim in batch_dims:
3236            a = make_tensor(batch_dim + sparse_dims, dtype=torch.float, device=device)
3237            _to_from_layout(from_layout, to_layout, a)
3238
3239    @skipMeta
3240    @all_sparse_compressed_layouts()
3241    @batched_nonbatched()
3242    @hybrid_nonhybrid()
3243    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
3244    def test_dense_to_from_sparse_compressed(self, device, hybrid, batched, layout):
3245        """This test checks conversion between dense tensors and the sparse
3246        compressed layouts by comparing against SciPy's implementation.
3247
3248        Here we test only those conversion combinations that SciPy
3249        supports, to ensure that PyTorch conversions are on the same
3250        page as SciPy.  Independently of SciPy, all conversion
3251        combinations are tested in TestSparseAny.test_to_sparse.
3252        """
3253
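        # Illustrative sketch (plain non-hybrid, non-batched case, assuming a 2-D
        # float tensor `dense`): a conversion like
        #   dense.to_sparse(layout=torch.sparse_bsr, blocksize=(2, 2))
        # is compared below against the SciPy matrix constructed via
        # self._construct_sp_matrix from the same dense data.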
3254        blocked_layouts = (torch.sparse_bsr, torch.sparse_bsc)
3255
3256        # helpers
3257
3258        def _check_against_scipy_matrix(pt_matrix, dense, blocksize, **kwargs):
3259            # scipy has no bsc layout, so we check against the bsr layout of the transposed dense
3260            if layout == torch.sparse_bsc:
3261                sp_matrix = self._construct_sp_matrix(dense.t(), layout=torch.sparse_bsr, blocksize=blocksize[::-1])
3262            else:
3263                sp_matrix = self._construct_sp_matrix(dense, layout=layout, blocksize=blocksize)
3264
3265            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
3266
3267            self.assertEqual(layout, pt_matrix.layout)
3268            if layout == torch.sparse_bsc:
3269                self.assertEqual(sp_matrix.shape[::-1], pt_matrix.shape)
3270            else:
3271                self.assertEqual(sp_matrix.shape, pt_matrix.shape)
3272
3273            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
3274            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
3275            if layout == torch.sparse_bsc:
3276                # we must transpose the blocks before comparing
3277                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values().transpose(-2, -1))
3278            else:
3279                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())
3280
3281        def _check_hybrid_matrix(pt_matrix, dense, blocksize, **kwargs):
3282            # Calculate COO indices for sparse matrix.
3283            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
3284            compressed_indices = compressed_indices_mth(pt_matrix)
3285            plain_indices = plain_indices_mth(pt_matrix)
3286            coo_indices = torch._convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
3287            row_indices, col_indices = {
3288                torch.sparse_csr: (coo_indices[0, ], coo_indices[1, ]),
3289                torch.sparse_csc: (coo_indices[1, ], coo_indices[0, ]),
3290                torch.sparse_bsr: (coo_indices[0, ], coo_indices[1, ]),
3291                torch.sparse_bsc: (coo_indices[1, ], coo_indices[0, ]),
3292            }[pt_matrix.layout]
3293
3294            # If the sparse matrix layout is blocked, rearrange the dense matrix
3295            # so that the shape past the first two dimensions matches the
3296            # shape of the sparse matrix values.
3297            dense_to_check = dense
3298            if blocksize:
3299                dense_shape = dense.shape
3300                dense_to_check_shape = (dense.shape[0] // blocksize[0],
3301                                        blocksize[0],
3302                                        dense.shape[1] // blocksize[1],
3303                                        blocksize[1]) + dense.shape[2:]
3304                dense_to_check = dense_to_check.reshape(dense_to_check_shape).transpose(1, 2)
3305
3306            # Verify that the non-zero values of the sparse matrix are
3307            # equal to the corresponding values of the dense matrix.
3308            self.assertEqual(pt_matrix.values(), dense_to_check[row_indices, col_indices])
3309
3310            # Verify that the remaining elements of the dense matrix
3311            # are 0, i.e. that the dense and sparse matrices are fully
3312            # equal.
3313            mask = torch.ones_like(dense_to_check, dtype=torch.bool)
3314            mask[row_indices, col_indices] = False
3315            self.assertTrue(torch.all(torch.masked_select(dense_to_check, mask) == 0))
3316
3317        def _check_batched(pt_tensor, dense, check_batch=None, batch_shape=(), blocksize=(), **kwargs):
3318            self.assertEqual(layout, pt_tensor.layout)
3319            self.assertEqual(pt_tensor.shape, dense.shape)
3320            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
3321            for batch_index in np.ndindex(batch_shape):
3322                pt_matrix = pt_tensor[batch_index]
3323                dense_matrix = dense[batch_index]
3324                dense_dim = pt_matrix.dim() - 2
3325                dense_matrix_pt = dense_matrix.to_sparse(layout=layout,
3326                                                         blocksize=blocksize or None,
3327                                                         dense_dim=dense_dim)
3328                # sanity check: selecting a batch from the converted tensor and converting dense[batch] directly should give the same result
3329                self.assertEqual(pt_matrix, dense_matrix_pt)
3330                check_batch(pt_matrix, dense_matrix, blocksize, **kwargs)
3331
3332        def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
3333            shape = batch_shape + sparse_shape + hybrid_shape
3334            n_batch_dim = len(batch_shape)
3335            n_hybrid_dim = len(hybrid_shape)
3336            # generate a dense tensor
3337            dense = make_tensor(shape, dtype=torch.float, device=device)
3338
3339            # introduce some sparsity; the mask has the sparse shape, each element applies to an entire dense
3340            # sub-tensor (hybrid) and is applied to each batch
3341            mask = make_tensor(sparse_shape, dtype=torch.bool, device=device)
3342            # manually expand to match hybrid shape
3343            if hybrid:
3344                mask = mask.view(sparse_shape + tuple(1 for _ in range(n_hybrid_dim)))
3345                mask = mask.expand(sparse_shape + hybrid_shape)
3346
3347            # mask will broadcast over the batch dims if present
3348
3349            return dense * mask
3350
3351        # note: order is important here; hybrid-ness decides the inner content check, which is then used to build
3352        # the batched checker (if needed)
3353        check_content = _check_against_scipy_matrix
3354        if hybrid:
3355            check_content = _check_hybrid_matrix
3356        if batched:
3357            check_content = functools.partial(_check_batched, check_batch=check_content)
3358
3359        sparse_sizes = [(6, 10), (0, 10), (6, 0), (0, 0)]
3360        blocksizes = [(2, 2), (1, 1), (1, 2)] if layout in blocked_layouts else [()]
3361        batch_sizes = [(3,), (1, 3), (2, 1, 3)] if batched else [()]
3362        hybrid_sizes = [(4, ), (2, 2)] if hybrid else [()]
3363
3364        # general cases, always run
3365        for sparse_shape, blocksize, batch_shape, hybrid_shape in itertools.product(
3366                sparse_sizes, blocksizes, batch_sizes, hybrid_sizes):
3367            dense = _generate_subject(sparse_shape, batch_shape, hybrid_shape)
3368            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
3369            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)
3370            dense_back = sparse.to_dense()
3371            self.assertEqual(dense, dense_back)
3372
3373        # special cases for batched tensors
3374        if batched:
3375            # batched sparse tensors need only have the same number of non-zeros in each batch, not necessarily the
3376            # same sparsity pattern in each batch
3377            sparse_shape = sparse_sizes[0]
3378            hybrid_shape = hybrid_sizes[0]
3379            batch_shape = batch_sizes[0]
3380            shape = batch_shape + sparse_shape + hybrid_shape
3381            dense = make_tensor(shape, dtype=torch.float, device=device)
3382            blocksize = blocksizes[0]
3383            # number of elements/blocks in each batch (total not nnz)
3384            batch_mask_shape = sparse_shape
3385            if layout in blocked_layouts:
3386                # if we are blocked, the mask is generated for the block-valued elements
3387                batch_mask_shape = sparse_shape[0] // blocksize[0], sparse_shape[1] // blocksize[1]
3388
3389            # random bool vector w/ length equal to max possible nnz for the sparse_shape
3390            mask_source = make_tensor(batch_mask_shape, dtype=torch.bool, device=device).flatten()
3391            n_batch = functools.reduce(operator.mul, batch_shape, 1)
3392
3393            # stack random permutations of the source for each batch
3394            mask = torch.stack([mask_source[torch.randperm(mask_source.numel())]
3395                               for _ in range(n_batch)], dim=0).reshape(batch_shape + batch_mask_shape)
3396            if layout in blocked_layouts:
3397                # for blocked we need to do a bit of extra work to expand the mask from blocked-space to element-space
3398                mask_shape = mask.shape
3399                mask = mask.view(mask_shape + (1, 1))
3400                mask = mask.expand(mask_shape + blocksize)
3401                mask = mask.transpose(-3, -2)
3402                mask = mask.flatten(-4, -3).flatten(-2, -1)
3403            mask_shape = mask.shape
3404            mask = mask.view(mask_shape + (1,) * len(hybrid_shape))
3405            mask = mask.expand(mask_shape + hybrid_shape)
3406            dense = dense * mask
3407            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
3408            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)
3409
3410            dense_back = sparse.to_dense()
3411            self.assertEqual(dense, dense_back)
3412
3413            # if batches have different nnz we expect the conversion to throw
3414            mask_0 = mask[0]
3415            mask_1 = mask[0].clone().fill_(True)
3416            mask_2 = mask[0].clone().fill_(False)
3417            mask_true = mask_source.clone().fill_(True)
3418            mask_false = mask_source.clone().fill_(False)
3419            mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
3420            dense = make_tensor(shape, dtype=torch.float, device=device)
3421            dense = dense * mask
3422            msg = "Expect the same number of specified elements per batch."
3423            with self.assertRaisesRegex(RuntimeError, msg):
3424                dense.to_sparse(layout=layout, blocksize=blocksize or None)
3425
3426            # Should throw if any batch dimension is zero
3427            dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
3428            layout_code = str(layout).split("_")[-1]
3429            msg = f"to_sparse_{layout_code}: Expected product of batch dimensions to be non-zero."
3430            with self.assertRaisesRegex(RuntimeError, msg):
3431                dense.to_sparse(layout=layout, blocksize=blocksize or None)
3432
3433    @skipMeta
3434    @all_sparse_compressed_layouts()
3435    @coalescedonoff
3436    @dtypes(torch.double)
3437    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
3438    def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout):
3439        """
3440        This test checks conversion from COO, and from CSC, to the sparse
3441        compressed layouts by comparing against SciPy's implementation.
3442
3443        Here we test only those conversion combinations that SciPy
3444        supports, to ensure that PyTorch conversions are on the same
3445        page as SciPy.  Independently of SciPy, all conversion
3446        combinations are tested in TestSparseAny.test_to_sparse.
3447        """
3448
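        # Illustrative sketch (assuming a 2-D COO tensor `sparse`): a conversion like
        #   sparse.to_sparse(layout=torch.sparse_csr)
        # is compared below against SciPy by matching compressed/plain indices and
        # values with the indptr/indices/data of the SciPy matrix.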
3449        blocksize_kw = {}
3450        if layout in (torch.sparse_bsc, torch.sparse_bsr):
3451            blocksize_kw['blocksize'] = (2, 2)
3452            # block modes don't support 0 width/height
3453            shapes = [(6, 10)]
3454        elif layout in (torch.sparse_csc, torch.sparse_csr):
3455            shapes = [(0, 10), (6, 0), (6, 10), (0, 0)]
3456        else:
3457            raise NotImplementedError("unhandled layout")
3458
3459        if layout in (torch.sparse_bsc, torch.sparse_csc):
3460            compressed_indices_mth = torch.Tensor.ccol_indices
3461            plain_indices_mth = torch.Tensor.row_indices
3462        elif layout in (torch.sparse_bsr, torch.sparse_csr):
3463            compressed_indices_mth = torch.Tensor.crow_indices
3464            plain_indices_mth = torch.Tensor.col_indices
3465        else:
3466            raise NotImplementedError("unhandled layout")
3467
3468        for shape in shapes:
3469            sparse_dim = 2
3470            nnz = shape[0] * shape[1] // 2
3471            sparse, _, _ = self.genSparseTensor(shape, sparse_dim, nnz, coalesced, device, dtype)
3472            sp_matrix = self._construct_sp_matrix(sparse, layout)
3473            pt_matrix = sparse.to_sparse(layout=layout, **blocksize_kw)
3474
3475            self.assertEqual(layout, pt_matrix.layout)
3476            self.assertEqual(sp_matrix.shape, pt_matrix.shape)
3477            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
3478            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
3479            self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())
3480
3481            sparse_csc = sparse.to_sparse_csc()
3482            sp_matrix = self._construct_sp_matrix(sparse_csc, layout)
3483            pt_matrix = sparse_csc.to_sparse(layout=layout, **blocksize_kw)
3484
3485            self.assertEqual(layout, pt_matrix.layout)
3486            self.assertEqual(sp_matrix.shape, pt_matrix.shape)
3487            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
3488            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
3489            self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())
3490
3491    @unittest.skipIf(not TEST_CUDA_CUDSS, "The test requires cudss")
3492    @dtypes(*floating_types())
3493    def test_linalg_solve_sparse_csr_cusolver(self, device, dtype):
3494        # https://github.com/krshrimali/pytorch/blob/f5ee21dd87a7c5e67ba03bfd77ea22246cabdf0b/test/test_sparse_csr.py
3495
3496        try:
3497            spd = torch.rand(4, 3)
3498            A = spd.T @ spd
3499            b = torch.rand(3).cuda()
3500            A = A.to_sparse_csr().cuda()
3501            x = torch.sparse.spsolve(A, b)
3502        except RuntimeError as e:
3503            if "Calling linear solver with sparse tensors requires compiling " in str(e):
3504                self.skipTest("PyTorch was not built with cuDSS support")
3505
3506        samples = sample_inputs_linalg_solve(None, device, dtype)
3507
3508        for sample in samples:
3509            if sample.input.ndim != 2:
3510                continue
3511
3512            out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device)
3513            if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1:
3514                with self.assertRaisesRegex(RuntimeError, "b must be a 1D tensor"):
3515                    out = torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs)
3516                break
3517            if not sample.args[0].numel():
3518                with self.assertRaisesRegex(RuntimeError,
3519                                            "Expected non-empty other tensor, but found empty tensor"):
3520                    torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out)
3521                break
3522
3523            expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
3524            sample.input = sample.input.to_sparse_csr()
3525            if sample.args[0].ndim != 1 and sample.args[0].size(-1) == 1:
3526                expect = expect.squeeze(-1)
3527                sample.args = (sample.args[0].squeeze(-1), )
3528            out = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs)
3529            self.assertEqual(expect, out)
3530
3531
3532def skipIfNoTriton(cls):
3533    from torch.utils._triton import has_triton
3534
3535    # no-op if triton is present
3536    if has_triton():
3537        return cls
3538    else:
3539
3540        @functools.wraps(cls, updated=())
3541        class skipped_cls(cls):
3542            def setUp(self):
3543                self.skipTest("Triton is not available.")
3544
3545        return skipped_cls
3546
3547@skipIfNoTriton
3548class TestSparseCompressedTritonKernels(TestCase):
3549
3550    def _to_block_triangular_inplace(self, d, row_block, col_block):
3551        """
3552        This function modifies `d` to become (upper/lower) block-triangular in-place.
3553        It is assumed that `d.shape[-2]` is divisible by `row_block` and
3554        `d.shape[-1]` is divisible by `col_block`.
3555        """
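        # For example, for square inputs the blocks strictly below the block diagonal
        # are zeroed (block-triu), while for tall inputs (more row blocks than column
        # blocks) the blocks strictly above it are zeroed (block-tril).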
3556
3557        from torch.sparse._triton_ops import tile_to_blocksize
3558
3559        m, n = d.shape[-2:]
3560        d_tiled = tile_to_blocksize(d, (row_block, col_block))
3561        d_tiled = d_tiled.moveaxis(-4, -1).moveaxis(-4, -1)
3562        if m // row_block > n // col_block:
3563            d_tiled.tril_()
3564        else:
3565            d_tiled.triu_()
3566
3567        return d
3568
3569    @onlyCUDA
3570    @skipIfRocm(msg="test is too slow on ROCm stack")
3571    @dtypes(torch.half, torch.bfloat16, torch.float)
3572    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3573    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3574    def test_triton_bsr_softmax(self, device, dtype):
3575        from functools import partial
3576        from torch.sparse._triton_ops import bsr_softmax
3577
3578        tensor = partial(make_tensor, device=device, dtype=dtype, low=1.0, high=3.0)
3579
3580        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3581        batches = [(), (2,), (2, 2)]
3582        size = [6, 12, 0]
3583        block_size = [2, 3]
3584
3585        # General correctness
3586        for row_block, col_block, b, m, n in itertools.product(block_size, block_size, batches, size, size):
3587            input = tensor(b + (m, n))
3588            input.diagonal(dim1=-2, dim2=-1).fill_(m * n)
3589            input = self._to_block_triangular_inplace(input, row_block, col_block)
3590
3591            bsr = input.to_sparse_bsr((row_block, col_block))
3592            coo = input.to_sparse().to(torch.float)
3593
3594            res_tri = bsr_softmax(bsr)
3595            res_coo = torch.sparse.softmax(coo, -1)
3596            self.assertEqual(res_tri, res_coo.to(input.dtype))
3597
3598        # Test long rows which exceed Triton's max numel limit set to 2 ** 17
3599        input = tensor(b + (1, 150000))
3600        bsr = input.to_sparse_bsr(1)
3601        self.assertEqual(input.softmax(-1), bsr_softmax(bsr))
3602
3603    @parametrize("block_size", [16, 32, 64])
3604    @parametrize("index_dtype", [torch.int32, torch.int64])
3605    @onlyCUDA
3606    @dtypes(torch.half, torch.bfloat16, torch.float)
3607    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3608    @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(),
3609                     "Skipped for deploy and internal with remote GPUs")
3610    def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
3611        from functools import partial
3612        from torch.sparse._triton_ops import bsr_dense_mm
3613
3614        def kernel_impl(*args, **kwargs):
3615            return bsr_dense_mm(*args, skip_checks=True, **kwargs)
3616
3617        kernel = torch._TritonLibrary.registerOp(
3618            "_triton_bsr_dense_mm_out",
3619            "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)",
3620            kernel_impl,
3621            "SparseCsrCUDA"
3622        )
3623
3624        # kernel != kernel_impl means dispatch was already registered.
3625        # This is exactly what we need!
3626        self.assertTrue(kernel is not kernel_impl)
3627
3628        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
3629        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
3630
3631        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3632        batches = [(), (2,), (2, 2)]
3633        size = [128, 256, 0]
3634
3635        # Whether to make inputs orthogonal so that the product is zero
3636        make_orthogonal = [True, False]
3637
3638        for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
3639            bsr = tensor(bs + (m, k))
3640            # NOTE: do not get confused, `dense` will be transposed before the product
3641            dense = tensor(bd + (n, k))
3642
3643            if is_ortho:
3644                bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
3645                dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)
3646
3647            bsr = bsr.to_sparse_bsr(block_size)
3648
3649            if bsr.dim() == 2 and dtype != torch.float:
3650                # Test against linear to check dispatch
3651                # which takes place for torch.half and torch.bfloat16.
3652                res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
3653                res_tri_out = torch.empty_like(res_dense)
3654                res_tri = torch.nn.functional.linear(dense, bsr, out=res_tri_out)
3655
3656                # Check dispatch worked with non-trivial outputs
3657                if m > 0 and n > 0 and k > 0:
3658                    self.assertTrue(kernel.kernel_invoked)
3659                    kernel.kernel_invoked = False
3660            else:
3661                # Otherwise check correctness against bmm
3662                # since nn.linear does not support bsr.dim() > 2.
3663                res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
3664                res_tri_out = torch.empty_like(res_dense)
3665                res_tri = kernel(bsr, dense.transpose(-2, -1), out=res_tri_out)
3666
3667            self.assertTrue(res_tri is res_tri_out)
3668            self.assertEqual(res_tri, res_dense)
3669
3670            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
3671            # check whether bsr_dense_mm handles different grid sizes
3672            # None means max possible grid size which is CUDA-dependent.
3673            grid_size = (None, 2, 4)
3674            grid_gen = itertools.product(grid_size, repeat=3)
3675            for grid in grid_gen:
3676                res_tri = torch.sparse._triton_ops.bsr_dense_mm(
3677                    bsr,
3678                    dense.transpose(-2, -1),
3679                    max_grid=grid,
3680                )
3681                self.assertEqual(res_tri, res_dense)
3682
3683    @onlyCUDA
3684    @dtypes(torch.half)
3685    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(),
3686                     "Skipped for deploy and internal with remote GPUs")
3687    def test_triton_bsr_dense_bmm_error_messages(self, device, dtype):
3688        from torch.sparse._triton_ops import bsr_dense_mm
3689
3690        rhs = torch.rand(32, 32, dtype=dtype, device=device)
3691        lhs = rhs.to_sparse_bsr(16)
3692        with self.assertRaisesRegex(ValueError, "only BSR sparse format is supported"):
3693            bsr_dense_mm(lhs.to_sparse_bsc(16), rhs)
3694        with self.assertRaisesRegex(ValueError, "on the same GPU device"):
3695            bsr_dense_mm(lhs, rhs.cpu())
3696        if torch.cuda.device_count() > 1:
3697            with self.assertRaisesRegex(ValueError, "on the same GPU device"):
3698                bsr_dense_mm(lhs.to("cuda:0"), rhs.to("cuda:1"))
3699        with self.assertRaisesRegex(ValueError, "all inputs are expected to be of the same dtype"):
3700            bsr_dense_mm(lhs, rhs.to(torch.float))
3701        with self.assertRaisesRegex(ValueError, r"and one of \(half, bfloat16, float32\)"):
3702            bsr_dense_mm(lhs.to(torch.double), rhs.to(torch.double))
3703        with self.assertRaisesRegex(ValueError, "all inputs involved in the matrix product are expected to be at least 2D"):
3704            bsr_dense_mm(lhs, torch.rand(1, dtype=dtype, device=device))
3705        with self.assertRaisesRegex(ValueError,
3706                                    "sizes involved in the matrix product are not compatible for matrix multiplication"):
3707            bsr_dense_mm(lhs, torch.rand(1, 1, dtype=dtype, device=device))
3708        with self.assertRaisesRegex(ValueError,
3709                                    r"dense.size\(-1\) == 15 should be divisible by 16"):
3710            bsr_dense_mm(lhs, torch.rand(32, 15, dtype=dtype, device=device))
3711        # Blocksizes check
3712        for blocksize in (15, 30):
3713            n = blocksize * 2
3714            rhs = torch.rand(n, n, dtype=dtype, device=device)
3715            lhs = rhs.to_sparse_bsr(blocksize)
3716            with self.assertRaisesRegex(ValueError, "should be at least 16 and a power of 2"):
3717                bsr_dense_mm(lhs, rhs)
3718        # out check
3719        rhs = torch.rand(2, 32, 32, dtype=dtype, device=device)
3720        lhs = rhs.to_sparse_bsr(16)
3721        with self.assertRaisesRegex(ValueError, r"`out` argument has wrong shape"):
3722            out = torch.rand(2, 30, 30, dtype=dtype, device=device)
3723            bsr_dense_mm(lhs, rhs, out=out)
3724        with self.assertRaisesRegex(ValueError, r"only row-major/col-major `out`"):
3725            out = torch.rand(32, 32, 2, dtype=dtype, device=device).transpose(0, -1)
3726            bsr_dense_mm(lhs, rhs, out=out)
3727
3728    @parametrize("block_size", [16, 32, 64])
3729    @onlyCUDA
3730    @skipIfRocm
3731    @dtypes(torch.half, torch.bfloat16, torch.float)
3732    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3733    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3734    @precisionOverride({torch.float16: 1e-3})
3735    def test_triton_scaled_dot_product_attention(self, device, dtype, block_size):
3736        from functools import partial
3737        from torch.sparse._triton_ops import _scaled_dot_product_attention
3738
3739        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
3740        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2)
3741
3742        def broadcast_input(*ts):
3743            batch_dims = torch.broadcast_shapes(*(t.shape[:-2] for t in ts))
3744            yield from (torch.broadcast_to(t, batch_dims + t.shape[-2:]) for t in ts)
3745
3746        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3747        batches = [(), (2,), (2, 2)]
3748        size = [128, 256, 0]
3749
3750        for bam, bq, bk, bv, m, n, k in itertools.product(batches, batches, batches, batches, size, size, size):
3751            query = tensor(bq + (m, k))
3752            key = tensor(bk + (n, k))
3753            value = tensor(bv + (n, k))
3754
3755            # We make attn_mask block lower/upper triangular so that BSR and Strided
3756            # function variants are directly comparable.
3757            attn_mask = torch.ones(bam + (m, n), device=device, dtype=torch.bool)
3758            attn_mask = self._to_block_triangular_inplace(attn_mask, block_size, block_size)
3759            attn_mask_bsr = attn_mask.to_sparse_bsr(block_size)
3760
3761            # NOTE: only boolean mask is directly compatible with the Strided version
3762            # without any pre-/post-processing. Hence we test against a boolean mask.
3763            for scale in (None, 1. / 16):
3764                if scale is None and query.size(-1) == 0:
3765                    scale = 1
3766                expected = torch.nn.functional.scaled_dot_product_attention(
3767                    *broadcast_input(query, key, value, attn_mask), scale=scale
3768                )
3769
3770                for mask_dtype in (torch.bool, dtype):
3771                    res = _scaled_dot_product_attention(query, key, value, attn_mask_bsr.to(mask_dtype), scale=scale)
3772                    self.assertEqual(res, expected)
3773
3774
3775    @parametrize("block_size", [16, 32, 64])
3776    @onlyCUDA
3777    @skipIfRocm
3778    @dtypes(torch.half, torch.bfloat16, torch.float)
3779    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3780    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3781    def test_triton_sampled_addmm(self, device, dtype, block_size):
3782        from functools import partial
3783        from torch.sparse._triton_ops import sampled_addmm, broadcast_batch_dims_bsr
3784
3785        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
3786        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2)
3787
3788        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3789        batches = [(), (2,), (2, 2)]
3790        size = [128, 256, 0]
3791
3792        delta_k = (-3,)
3793        for bi, bm1, bm2, m, n, k, dk in itertools.product(batches, batches, batches, size, size, size, delta_k):
3794            # Test non-power-of-2 values of k as well.
3795            k = max(0, k + dk)
3796            # Non-trivial sparsity pattern.
3797            # Also, with tril inputs the result is tril as well,
3798            # so we can compare BSR and CSR implementations.
3799            input = tensor(bi + (m, n)).tril_()
3800            bsr = input.to_sparse_bsr(block_size)
3801            mat1 = tensor(bm1 + (m, k)).tril_()
3802            mat2 = tensor(bm2 + (k, n)).tril_()
3803
3804            batch_dim = torch.broadcast_shapes(input.shape[:-2], mat1.shape[:-2], mat2.shape[:-2])
3805
3806            csr = input.broadcast_to(batch_dim + input.shape[-2:]).to_sparse_csr().to(torch.float)
3807            mat1csr = mat1.broadcast_to(batch_dim + mat1.shape[-2:]).to(torch.float)
3808            mat2csr = mat2.broadcast_to(batch_dim + mat2.shape[-2:]).to(torch.float)
3809
3810            input_broadcasted_clone = broadcast_batch_dims_bsr(
3811                "test_triton_sampled_addmm",
3812                bsr, mat1, mat2
3813            ).clone()
3814            input_broadcasted_clone = torch.sparse_compressed_tensor(
3815                input_broadcasted_clone.crow_indices(),
3816                input_broadcasted_clone.col_indices(),
3817                # For testing `out=`, let's give the values "weird" strides
3818                # so that if the kernel modifies the values for its own needs, the result
3819                # is still copied into out.values.
3820                input_broadcasted_clone.values().transpose(-3, -2).contiguous().transpose(-3, -2),
3821                layout=input_broadcasted_clone.layout,
3822                size=input_broadcasted_clone.shape
3823            )
3824
3825            scalars = (0.0, 2.0)
3826            for alpha, beta, out in itertools.product(scalars, scalars, (None, input_broadcasted_clone)):
3827                res_tri = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, out=out)
3828                if out is not None:
3829                    self.assertTrue(res_tri is out)
3830
3831                batch_broadcasted_shape = torch.broadcast_shapes(*(t.shape[:-2] for t in (input, mat1, mat2)))
3832                self.assertTrue(res_tri.shape == batch_broadcasted_shape + (m, n))
3833
3834                res_csr = torch.sparse.sampled_addmm(csr, mat1csr, mat2csr, alpha=alpha, beta=beta).to(input.dtype)
3835                self.assertEqual(res_tri.to_dense(), res_csr.to_dense())
3836
3837                # Check different grid sizes to make sure that input slicing works
3838                # if this input is larger than the grid.
3839                grid_size = (3, None)
3840                grid_gen = itertools.product(grid_size, repeat=2)
3841                for grid in grid_gen:
3842                    res_tri_grid = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, max_grid=grid)
3843                    self.assertEqual(res_tri, res_tri_grid)
3844
3845    @onlyCUDA
3846    @skipIfRocm
3847    @dtypes(torch.half, torch.bfloat16, torch.float)
3848    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3849    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3850    def test_triton_scatter_mm(self, device, dtype):
3851        from torch.sparse._triton_ops import scatter_mm
3852        from functools import partial
3853        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
3854        sizes = [8, 16]
3855        for m, k, n in itertools.product(sizes, sizes, sizes):
3856            blocks = torch.stack([tensor(m, k), tensor(m, k)])
3857            others = torch.stack([tensor(k, n), tensor(k, n)])
3858
3859            expected = torch.stack([blocks[0] @ others[0] + blocks[1] @ others[0],
3860                                    blocks[0] @ others[1],
3861                                    blocks[1] @ others[1]])
3862
3863            indices_data = (
3864                'scatter_mm',
3865                torch.tensor([0, 2, 3, 4], dtype=torch.int32, device=device),
3866                torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32, device=device))
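            # Note: the offsets tensor [0, 2, 3, 4] above partitions the (block, other)
            # index pairs into groups: pairs 0..1 accumulate into the first output
            # block, pair 2 into the second, and pair 3 into the third, which matches
            # `expected` above.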
3867
3868            result = scatter_mm(blocks, others, indices_data=indices_data)
3869
3870            self.assertEqual(result, expected)
3871
3872            indices_data = (
3873                'bsr_strided_mm',
3874                torch.tensor([0, 2, 4, 5, 6], dtype=torch.int32, device=device),
3875                torch.tensor([0, n, 2 * n * m, 2 * n * m + n], dtype=torch.int32, device=device),
3876                torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.int32, device=device),
3877                torch.tensor([0, 2 * k * n, n, 2 * k * n + n, 2 * k * n, 2 * k * n + n],
3878                             dtype=torch.int32, device=device),
3879                dict(SPLIT_N=2, is_compressed=False, TILE_M=m, TILE_N=n, GROUP_SIZE=1)
3880            )
3881
3882            for bsize in [(), (2,), (3, 4)]:
3883                other = tensor(*bsize, 2 * k, 2 * n)
3884                expected = torch.cat([
3885                    torch.cat([blocks[1], blocks[0]], dim=1),
3886                    torch.cat([torch.zeros_like(blocks[0]), blocks[1]], dim=1)], dim=0) @ other
3887                result = scatter_mm(blocks, other, indices_data=indices_data)
3888                self.assertEqual(result, expected)
3889
3890    @parametrize("blocksize", [2, '2x3', 16, '16x32', 32, 64])
3891    @onlyCUDA
3892    @dtypes(torch.half, torch.bfloat16, torch.float)
3893    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float)
3894    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
3895    def test_triton_bsr_scatter_mm(self, device, dtype, blocksize):
3896        import triton
3897        from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
3898        from functools import partial
3899        if isinstance(blocksize, str):
3900            blocksize = tuple(map(int, blocksize.split('x')))
3901        else:
3902            blocksize = (blocksize,) * 2
3903        # Note that each value in a non-zero block is in range blocksize * [low^2, high^2).
3904        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
3905
3906        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
3907        batches = [(), (2,), (2, 2)]
3908        sizes = [blocksize[0], 2 * blocksize[0], 4 * blocksize[0]]
3909        sizes_K = [blocksize[1], 2 * blocksize[1]]
3910
3911        for bd, bs, M, K, N, has_zero_row_block in itertools.product(batches, batches[:1], sizes, sizes_K, sizes, (False, True)):
3912            bsr_dense = tensor(bs + (M, K))
3913            if has_zero_row_block:
3914                if M > blocksize[0]:
3915                    bsr_dense[:blocksize[0]].zero_()
3916                else:
3917                    continue
3918            bsr = bsr_dense.to_sparse_bsr(blocksize)
3919            dense = tensor(bd + (K, N))
3920            expected = bsr.to_dense() @ dense
3921
3922            for indices_format in ('bsr_strided_mm', 'bsr_strided_mm_compressed', 'scatter_mm'):
3923                if indices_format in {'bsr_strided_mm', 'bsr_strided_mm_compressed'}:
3924                    SPLIT_N_list = [N]
3925                    while SPLIT_N_list[-1] > 1:
3926                        SPLIT_N_list.append(max(1, SPLIT_N_list[-1] // 2))
3927                else:
3928                    SPLIT_N_list = [1]
3929                for SPLIT_N in SPLIT_N_list:
3930                    indices_data = bsr_scatter_mm_indices_data(
3931                        bsr, dense, indices_format=indices_format, SPLIT_N=SPLIT_N)
3932                    try:
3933                        result = bsr_scatter_mm(bsr, dense, indices_data=indices_data)
3934                    except triton.compiler.OutOfResources:
3935                        # ensure that there was at least one successful test:
3936                        assert SPLIT_N < SPLIT_N_list[0]
3937                        break
3938
3939                    self.assertEqual(result, expected)
3940        torch.sparse._triton_ops._bsr_scatter_mm_indices_data.cache_clear()
3941
3942    def test_TensorAsKey(self, device):
3943        from torch.sparse._triton_ops import TensorAsKey
3944        assertEqualOptions = dict(exact_dtype=True, exact_device=True, exact_layout=True)
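        # As exercised below, TensorAsKey is expected to compare equal for views that
        # share storage, offset, shape and strides, and to hold the tensor only
        # weakly, so key.obj becomes None once all strong references are gone.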
3945
3946        t = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device=device)
3947        key = TensorAsKey(t)
3948        self.assertTrue(key == TensorAsKey(t))
3949        self.assertTrue(key.obj is t)
3950
3951        t2 = t[:]
3952        key2 = TensorAsKey(t2)
3953        self.assertTrue(key == key2)
3954        self.assertEqual(key2.obj, t, **assertEqualOptions)
3955        # deleting the object leads to a dead key
3956        del t2
3957        self.assertTrue(key2.obj is None)
3958        self.assertTrue(key.obj is t)
3959
3960        # key with different storage offset and shape:
3961        self.assertFalse(key == TensorAsKey(t[1:]))
3962
3963        # key with different strides:
3964        self.assertFalse(key == TensorAsKey(t[::2]))
3965
3966        # when object dies, make sure that key represents a dead
3967        # object as well:
3968        del t
3969        self.assertTrue(key.obj is None)
3970
3971        # Storing a tensor as a dict key:
3972        d = {}
3973        t3 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
3974        key3 = TensorAsKey(t3)
3975        d[key3] = 123
3976        self.assertTrue(d.get(key3) == 123)
3977        t3_ = t3[:]
3978        self.assertTrue(d.get(TensorAsKey(t3_)) == 123)
3979        self.assertTrue(d.get(TensorAsKey(t3.clone())) is None)
3980
3981        d[TensorAsKey(t3_)] = 567
3982        self.assertTrue(d.get(key3) == 567)
3983
3984        # t3 and t3_ reference the same data, so the key becomes dead
3985        # (that is, its .obj property returns None) only after all
3986        # references are deleted:
3987        del t3
3988        self.assertTrue(key3.obj is not None)
3989        self.assertTrue(d.get(key3) == 567)
3990        del t3_
3991        self.assertTrue(key3.obj is None)
3992        self.assertTrue(d.get(key3) == 567)
3993
3994        # Storing a tensor as a dict key and value:
3995        d = {}
3996        t4 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
3997        key4 = TensorAsKey(t4)
3998        d[key4] = (t4, 123)
3999        self.assertEqual(d.get(key4), (t4, 123), **assertEqualOptions)
4000        # when the tensor variable is deleted, the key still represents a live
4001        # object because the tensor is referenced by the dict item value:
4002        del t4
4003        self.assertTrue(key4.obj is not None)
4004        # This also means that the lifetime of the tensor is the same as
4005        # the lifetime of the corresponding dict item:
4006        del d[key4]
4007        self.assertTrue(key4.obj is None)
4008
4009        # Storing a tensor as a dict key and value wrapped with TensorAsKey:
4010        d = {}
4011        t5 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
4012        key5 = TensorAsKey(t5)
4013        d[key5] = (key5, 567)
4014        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
4015        self.assertTrue(key5.obj is not None)
4016        # when the object is deleted, the key becomes dead because the wrapped
4017        # value holds the tensor instance only as a weakref:
4018        del t5
4019        self.assertTrue(key5.obj is None)
4020        # but key is still valid:
4021        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
4022
4023    @suppress_warnings
4024    @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear', '_int_bsr_dense_addmm'])
4025    @parametrize("blocksize", [16, '16x32', 32])
4026    @onlyCUDA
4027    @skipIfRocm
4028    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
4029    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
4030    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
4031    def test_triton_kernel(self, op, device, dtype, blocksize):
4032        from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
4033        from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
4034                                                   optimize_bsr_dense_addmm, dump)
4035
4036        def bsr_dense_linear(input, weights, bias=None):
4037            return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)
4038
4039        operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
4040                         _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
4041
4042        def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
4043            assert mat1.layout is torch.strided
4044            assert mat2.layout is torch.strided
4045            if dtype is torch.int8:
4046                if op == '_int_bsr_dense_addmm':
4047                    return beta * input + alpha * torch._int_mm(mat1, mat2)
4048                # workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
4049                return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
4050            return beta * input + alpha * (mat1 @ mat2)
4051
4052        if op == '_int_bsr_dense_addmm':
4053            # _int_bsr_dense_addmm is the same as bsr_dense_addmm except
4054            # that with int8 inputs, _int_bsr_dense_addmm returns an int32
4055            # result. This is covered by the operation and reference
4056            # definitions above; all other definitions below are
4057            # identical between _int_bsr_dense_addmm and
4058            # bsr_dense_addmm.
4059            op = 'bsr_dense_addmm'
4060
4061        def nc_copy(t, axes=(-1,)):
4062            """Return a copy of input.
4063
4064            The returned copy will be a non-contiguous tensor.
4065            """
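            # Sketch of the approach: allocate a buffer twice as large along `axes`
            # and copy `t` into every other element, so the returned view has
            # non-unit strides along those axes.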
4066            if t.layout is torch.strided:
4067                shape = list(t.shape)
4068                for a in axes:
4069                    shape[a] *= 2
4070                r = torch.empty(shape, dtype=t.dtype, device=t.device)
4071                s = r[tuple(slice(None, None, 2 if t.shape[i] != r.shape[i] else None) for i in range(t.ndim))]
4072                s.copy_(t)
4073                return s
4074            elif t.layout is torch.sparse_bsr:
4075                compressed_indices = t.crow_indices()
4076                plain_indices = t.col_indices()
4077                return torch.sparse_compressed_tensor(compressed_indices, plain_indices, nc_copy(t.values()),
4078                                                      t.shape, layout=t.layout)
4079            else:
4080                raise NotImplementedError(t.layout)
4081
4082        if isinstance(blocksize, str):
4083            BM, BK = tuple(map(int, blocksize.split('x')))
4084        else:
4085            BM, BK = (blocksize,) * 2
4086
4087        if op in {"bsr_dense_linear"} and BM != BK:
4088            # todo: eliminate this skip
4089            self.skipTest(f"{op} does not support non-square blocks")
4090
4091        if op in {"bsr_dense_linear"} and dtype is torch.int8:
4092            # todo: eliminate this skip
4093            self.skipTest(f"{op} does not support int8")
4094
4095        if dtype is torch.int8 and min(BM, BK) < 32:
4096            self.skipTest("triton kernel does not support int8 blocks smaller than 32")
4097
4098        beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
4099        alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
4100        sparsity_lst = [0, 0.5, 1]
4101        blocks_per_row_lst = [1, 2]
4102        blocks_per_col_lst = [1, 2]
4103        result_cols_lst = [16, 32, 64]
4104        for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
4105                beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
4106            M = BM * blocks_per_row
4107            K = BK * blocks_per_col
4108            mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
4109            bsr = mat1.to_sparse_bsr((BM, BK))
4110            mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
4111            input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)
4112
4113            if 0 and op == "bsr_dense_addmm":
4114                # Find optimal kernel parameters; the resulting speed-up is
4115                # about 10x when running this test.
4116                #
4117                # Enable this if-block when the test method is
4118                # updated, run the test, and finally, disable the
4119                # if-block.
4120                key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
4121                meta = get_meta(op, key, version=(0, dtype, 0.5))
4122                if meta is None:
4123                    optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
4124                    meta = get_meta(op, key, version=(0, dtype, 0.5))
4125                    assert meta is not None
4126                    dump()  # this will update torch/sparse/_triton_ops_meta.py
4127
4128            expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
4129            kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={},
4130                          bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]
4131
4132            args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
4133                        bsr_dense_linear=(mat2.transpose(-1, -2), bsr))[op]
4134            result = operation(*args, **kwargs)
4135            self.assertEqual(result, expected)
4136
4137            # Test non-contiguous input tensors:
4138            nc_mat2 = nc_copy(mat2)
4139            nc_input = nc_copy(input)
4140            nc_bsr = nc_copy(bsr)
4141
4142            args = dict(bsr_dense_addmm=(input, bsr, nc_mat2), bsr_dense_mm=(bsr, nc_mat2),
4143                        bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
4144            result = operation(*args, **kwargs)
4145            self.assertEqual(result, expected)
4146
4147            # todo: add bsr_dense_linear to the set below (currently,
4148            # nn.linear has unnecessarily restrictive argument
4149            # checks).
4150            if op in {'bsr_dense_addmm', 'bsr_dense_mm'}:
4151                args = dict(bsr_dense_addmm=(input, nc_bsr, mat2), bsr_dense_mm=(nc_bsr, mat2),
4152                            bsr_dense_linear=(mat2.transpose(-1, -2), nc_bsr))[op]
4153                result = operation(*args, **kwargs)
4154                self.assertEqual(result, expected)
4155
4156            if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
4157                args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
4158                            bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
4159                kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
4160                              bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
4161                result = operation(*args, **kwargs)
4162                self.assertEqual(result, expected)
4163
4164    @parametrize("op", ['bsr_dense_addmm', '_int_bsr_dense_addmm'])
4165    @onlyCUDA
4166    @skipIfRocm
4167    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
4168    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
4169    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
4170    def test_triton_tune(self, op, device, dtype):
4171        from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm
4172        from torch.sparse._triton_ops_meta import (create_blocked_tensor, tune_bsr_dense_addmm, tune__int_bsr_dense_addmm, get_meta)
4173
4174        operation = dict(bsr_dense_addmm=bsr_dense_addmm, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
4175        tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm,
4176                     _int_bsr_dense_addmm=tune__int_bsr_dense_addmm)[op]
4177
4178        if op == '_int_bsr_dense_addmm':
4179            M, K, N = 32, 32, 32
4180            blocksize = (32, 32)
4181        else:
4182            M, K, N = 16, 16, 32
4183            blocksize = (16, 16)
4184        sparsity = 1.0
4185        bsr = create_blocked_tensor(0, M, K, blocksize, sparsity, dtype, device).to_sparse_bsr(blocksize)
4186        sparsity = 1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K)
4187        input = make_tensor(K, N, dtype=dtype, device=device)
4188        dense = make_tensor(K, N, dtype=dtype, device=device)
4189
4190        if op in {'bsr_dense_addmm', '_int_bsr_dense_addmm'}:
4191            args = (input, bsr, dense)
4192
4193            def get_current_meta():
4194                version = (0, dtype, sparsity)
4195                meta_key = (M, K, N, *blocksize, False, True, True)
4196                return get_meta(op, meta_key, version=version, exact=True)
4197        else:
4198            raise NotImplementedError(op)
4199
4200        self.assertEqual(get_current_meta(), None)
4201
4202        meta = tuner(*args, **dict(store=True, verbose=False))
4203        self.assertEqual(get_current_meta(), meta)
4204
4205        expected = operation(*args)
4206        result = operation(*args, **dict(meta=meta))
4207        self.assertEqual(result, expected)
4208
4209    @onlyCUDA
4210    @skipIfRocm
4211    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
4212    def test_triton_bsr_dense_addmm_meta(self, device):
4213        from torch.sparse._triton_ops import bsr_dense_addmm_meta
4214        from torch.sparse._triton_ops_meta import update as update_bsr_dense_addmm_meta
4215
4216        dtype = torch.float32
4217        Ms = Ks = 16
4218        beta = 0.0
4219        alpha = 1.0
4220
4221        def get_meta(M, K, N, sparsity=None):
4222            return bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, dtype=dtype, sparsity=sparsity,
4223                                        _version="test_triton_bsr_dense_addmm_meta")
4224
4225        def update_meta(M, K, N, value, sparsity=0.5):
4226            key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
4227            update_bsr_dense_addmm_meta("bsr_dense_addmm", torch.cuda.get_device_name(),
4228                                        ("test_triton_bsr_dense_addmm_meta", dtype, sparsity),
4229                                        key, value)
4230
4231        def get_meta_with_checks(M, K, N, warn_count=0, sparsity=None):
4232            f = io.StringIO()
4233            with redirect_stderr(f):
4234                result = get_meta(M, K, N, sparsity=sparsity)
4235            msg = f.getvalue()
4236            FileCheck().check_count(
4237                str=f"UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M={M} K={K} N={N}",
4238                count=warn_count, exactly=True
4239            ).run(msg)
4240            return result
4241
4242        # Test warn_once when requesting non-existing tuned parameters multiple times
4243        f = io.StringIO()
4244        with redirect_stderr(f):
4245            for i in range(5):
4246                get_meta(16, 16, 16)
4247            for i in range(5):
4248                get_meta(16, 16, 32)
4249
4250        msg = f.getvalue()
4251        FileCheck().check_count(
4252            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=16", count=1, exactly=True
4253        ).run(msg)
4254        FileCheck().check_count(
4255            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=32", count=1, exactly=True
4256        ).run(msg)
4257
4258        # Test warn_once when tuned parameters are missing
4259        default_meta = dict(GROUP_SIZE_ROW=4, SPLIT_N=2, num_stages=1, num_warps=4)
4260        self.assertEqual(get_meta_with_checks(32, 32, 32, warn_count=1), default_meta)
4261
4262        # Test (no)warn_once when tuned parameters are available
4263        update_meta(32, 32, 48, (2, 8, 5, 6))
4264        expected_meta = dict(GROUP_SIZE_ROW=2, SPLIT_N=8, num_stages=5, num_warps=6)
4265        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0), expected_meta)
4266
4267        # Test non-existing tuned parameters with non-default sparsity
4268        # while for default sparsity 0.5 the parameters are available
4269        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0, sparsity=0.6), expected_meta)
4270
4271        # Test non-existing tuned parameters while there exist
4272        # parameters with a consistent N // SPLIT_N ratio:
4273        self.assertEqual(get_meta_with_checks(32, 32, 72, warn_count=0),
4274                         dict(GROUP_SIZE_ROW=2, SPLIT_N=12, num_stages=5, num_warps=6))
4275        # ... or not:
4276        self.assertEqual(get_meta_with_checks(32, 32, 64, warn_count=1),
4277                         dict(GROUP_SIZE_ROW=4, SPLIT_N=4, num_stages=1, num_warps=4))
4278
4279
4280# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
4281instantiate_device_type_tests(TestSparseCSR, globals())
4282instantiate_device_type_tests(TestSparseCompressed, globals())
4283instantiate_device_type_tests(TestSparseCompressedTritonKernels, globals())
4284
4285if __name__ == '__main__':
4286    run_tests()
4287