# Owner(s): ["module: sparse"]

import torch
import random
import io
import itertools
import unittest
import functools
from contextlib import redirect_stderr
from torch.testing import make_tensor, FileCheck
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, TEST_CUSPARSE_GENERIC
from torch.testing._internal.common_utils import \
    (TEST_WITH_TORCHINDUCTOR, TEST_WITH_ROCM, TEST_CUDA_CUDSS, TEST_SCIPY, TEST_NUMPY, TEST_MKL, IS_WINDOWS, TestCase,
     run_tests, load_tests, coalescedonoff, parametrize, subtest, skipIfTorchDynamo, skipIfRocm, IS_FBCODE, IS_REMOTE_GPU,
     suppress_warnings)
from torch.testing._internal.common_device_type import \
    (ops, instantiate_device_type_tests, dtypes, OpDTypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoSparseGeneric,
     precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse, skipCUDAIfRocmVersionLessThan,
     largeTensorTest)
from torch.testing._internal.common_methods_invocations import \
    (op_db, sparse_csr_unary_ufuncs, ReductionOpInfo)
from torch.testing._internal.common_cuda import _get_torch_cuda_version, TEST_CUDA
from torch.testing._internal.common_dtype import (
    floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
    all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
import operator

if TEST_SCIPY:
    import scipy.sparse as sp

if TEST_NUMPY:
    import numpy as np

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

no_mkl_sparse = IS_WINDOWS or not TEST_MKL

def _check_cusparse_triangular_solve_available():
    version = _get_torch_cuda_version()
    # cusparseSpSM was added in 11.3.1 but we don't have access to the patch version
    min_supported_version = (11, 4)
    return version >= min_supported_version

def _check_cusparse_spgemm_available():
    # cusparseSpGEMM was added in 11.0
    return not TEST_WITH_ROCM

def _check_cusparse_sddmm_available():
    if TEST_WITH_ROCM:
        return True
    version = _get_torch_cuda_version()
    # cusparseSDDMM was added in 11.2.1 but we don't have access to the patch version
    min_supported_version = (11, 3)
    return version >= min_supported_version

_sparse_csr_ops = list(filter(lambda op: op.supports_sparse_csr, op_db))
_sparse_compressed_ops = list(filter(lambda op: (op.supports_sparse_csr or op.supports_sparse_csc
                                                 or op.supports_sparse_bsr or op.supports_sparse_bsc), op_db))
binary_functions_with_dense_output = ['mm', 'mv', ]
binary_ops_with_dense_output = list(filter(lambda op: op.name in binary_functions_with_dense_output, op_db))

UNARY_EWISE_CSR_ALLOW_AUTOGRAD = [
    'abs',
    'conj_physical',
    'deg2rad',
    'neg',
    'positive',
    'frac',
    'nn.functional.relu',
    'log1p',
    'rad2deg'
]

# This should be just an import from test_linalg instead of code duplication,
# but https://github.com/pytorch/pytorch/pull/63511#discussion_r733989701
def _test_addmm_addmv(
    test_case,
    f,
    t,
    m,
    v,
    *,
    alpha=None,
    beta=None,
    transpose_out=False,
    layout=torch.strided,
    mode=None
):
    """
    Unified test for checking `f(t, m, v, alpha=alpha, beta=beta)` computation,
    where f is `torch.addmv` or `torch.addmm`.
    `transpose_out` controls whether the out argument is in column-major order.
    `layout` controls whether `m` is converted to the specified layout or not.
    Custom behaviour is implemented only for the torch.sparse_csr layout.
99 """ 100 dtype = t.dtype 101 numpy_dtype = dtype 102 if dtype in {torch.bfloat16}: 103 numpy_dtype = torch.float 104 if dtype.is_complex: 105 alpha = 0.9 + 0.3j if alpha is None else alpha 106 beta = 0.5 + 0.6j if beta is None else beta 107 else: 108 alpha = 1.2 if alpha is None else alpha 109 beta = 0.8 if beta is None else beta 110 111 def convert_layout(mat): 112 if layout == torch.sparse_csr: 113 return mat.to_sparse_csr() 114 elif layout == torch.sparse_csc: 115 return mat.to_sparse_csc() 116 else: 117 assert mat.layout == layout 118 return mat 119 120 if mode == "all_sparse": 121 res1 = f(*map(convert_layout, (t, m, v)), alpha=alpha, beta=beta) 122 test_case.assertEqual(res1.layout, layout) 123 res1 = res1.to_dense() 124 elif mode == "dense_result": 125 res1 = f(t, convert_layout(m), convert_layout(v), alpha=alpha, beta=beta) 126 else: 127 res1 = f(t, convert_layout(m), v, alpha=alpha, beta=beta) 128 res2 = torch.full_like(res1, float('nan')) 129 if transpose_out: 130 res2 = res2.t().clone(memory_format=torch.contiguous_format).t() 131 f(t, convert_layout(m), v, alpha=alpha, beta=beta, out=res2) 132 res3 = alpha * (m.to(numpy_dtype).cpu().numpy() @ v.to(numpy_dtype).cpu().numpy()) 133 if beta != 0: 134 res3 += (beta * t).to(numpy_dtype).cpu().numpy() 135 res3 = torch.from_numpy(res3).to(dtype) 136 test_case.assertEqual(res1, res2) 137 test_case.assertEqual(res1, res3) 138 139 140class TestSparseCSRSampler(TestCase): 141 142 def test_make_crow_indices(self): 143 # Here we test the correctness of the crow_indices algorithm 144 # and testing it on CPU and with int32 dtype will be 145 # sufficient. 146 device = torch.device('cpu') 147 index_dtype = torch.int32 148 for n_rows in range(1, 10): 149 for n_cols in range(1, 10): 150 for nnz in range(0, n_rows * n_cols + 1): 151 crow_indices = self._make_crow_indices( 152 n_rows, n_cols, nnz, 153 device=device, dtype=index_dtype) 154 self.assertEqual(len(crow_indices), n_rows + 1) 155 counts = crow_indices[1:] - crow_indices[:-1] 156 self.assertEqual(counts.sum(), nnz) 157 self.assertGreaterEqual(counts.min(), 0) 158 self.assertLessEqual(counts.max(), n_cols) 159 160 161def all_sparse_compressed_layouts(test_name='layout'): 162 return parametrize(test_name, [ 163 subtest(torch.sparse_csr, name='SparseCSR'), 164 subtest(torch.sparse_csc, name='SparseCSC'), 165 subtest(torch.sparse_bsr, name='SparseBSR'), 166 subtest(torch.sparse_bsc, name='SparseBSC')]) 167 168 169def sparse_compressed_nonblock_layouts(test_name='layout'): 170 return parametrize(test_name, [ 171 subtest(torch.sparse_csr, name='SparseCSR'), 172 subtest(torch.sparse_csc, name='SparseCSC')]) 173 174 175sparse_compressed_indices_methods = { 176 torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), 177 torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), 178 torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices), 179 torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices), 180} 181 182 183def batched_nonbatched(test_name='batched'): 184 return parametrize(test_name, [ 185 subtest(True, name="Batched"), 186 subtest(False, name="NonBatched") 187 ]) 188 189 190def hybrid_nonhybrid(test_name='hybrid'): 191 return parametrize(test_name, [ 192 subtest(True, name="Hybrid"), 193 subtest(False, name="NonHybrid") 194 ]) 195 196 197class TestSparseCompressed(TestCase): 198 """Testing sparse compressed (CSR, CSC, BSR, BSC) tensor generic features. 
199 """ 200 201 def genTensor(self, size, nnz, *, layout, device=None, dtype=torch.float, index_dtype=torch.int64): 202 if device is None: 203 device = self.device_type 204 return self.genSparseCompressedTensor(size, nnz, device=device, dtype=dtype, index_dtype=index_dtype, layout=layout) 205 206 @all_sparse_compressed_layouts() 207 @onlyCPU 208 def test_layout(self, layout): 209 self.assertIn(str(layout), {'torch.sparse_csr', 'torch.sparse_csc', 'torch.sparse_bsr', 'torch.sparse_bsc'}) 210 self.assertEqual(type(layout), torch.layout) 211 212 @parametrize('shape_and_device_inference', [subtest(False, name='_'), subtest(True, name='shape_and_device_inference')]) 213 @parametrize('use_factory_function', [subtest(False, name='_'), subtest(True, name='factory')]) 214 @parametrize('input_kind', [subtest('tensor', name='from_tensor'), subtest('list', name='from_list')]) 215 @all_sparse_compressed_layouts() 216 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 217 def test_sparse_compressed_constructor(self, layout, device, dtype, 218 use_factory_function, shape_and_device_inference, input_kind): 219 if input_kind == 'list' and shape_and_device_inference: 220 if torch.device(device).type == 'cuda': 221 # list inputs to factory/constructor function without 222 # specifying device will result a sparse compressed tensor 223 # on CPU. So, skip testing against cuda device as unused. 224 self.skipTest("nothing to test") 225 if dtype not in {torch.float32, torch.complex64, torch.int64, torch.bool}: 226 self.skipTest("dtype not supported with list values") 227 228 expected_devices = [torch.device(device)] 229 if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2 and not shape_and_device_inference: 230 expected_devices.append(torch.device('cuda:1')) 231 232 factory_function = { 233 torch.sparse_csr: torch.sparse_csr_tensor, 234 torch.sparse_csc: torch.sparse_csc_tensor, 235 torch.sparse_bsr: torch.sparse_bsr_tensor, 236 torch.sparse_bsc: torch.sparse_bsc_tensor, 237 }[layout] 238 compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout] 239 if input_kind == 'list': 240 index_dtypes = [torch.int64] 241 else: 242 index_dtypes = [torch.int32, torch.int64] 243 if dtype.is_floating_point or dtype.is_complex: 244 requires_grad_lst = [False, True] 245 else: 246 requires_grad_lst = [False] 247 for index_dtype in index_dtypes: 248 for expected_device in expected_devices: 249 for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs( 250 layout, device=expected_device, dtype=dtype, index_dtype=index_dtype, 251 # skip zero-sized tensors for list inputs: 252 enable_zero_sized=input_kind != 'list', 253 output_tensor=False): 254 size = kwargs['size'] 255 if shape_and_device_inference and 0 in size: 256 # skip shape inference for zero-sized tensor 257 # inputs because (i) the shape determined from 258 # an empty list is ambiguous, and (ii) the 259 # size of the plain dimension defined as 260 # max(plain_indices) is undefined if 261 # plain_indices has no values 262 continue 263 compressed_indices_expect = compressed_indices 264 plain_indices_expect = plain_indices 265 values_expect = values 266 267 if input_kind == 'list': 268 compressed_indices = compressed_indices.tolist() 269 plain_indices = plain_indices.tolist() 270 values = values.tolist() 271 272 for requires_grad in requires_grad_lst: 273 if use_factory_function: 274 if shape_and_device_inference: 275 sparse = factory_function( 276 
                                    compressed_indices, plain_indices, values, requires_grad=requires_grad)
                            else:
                                sparse = factory_function(
                                    compressed_indices, plain_indices, values, size,
                                    dtype=dtype, device=expected_device, requires_grad=requires_grad)
                        else:
                            if shape_and_device_inference:
                                sparse = torch.sparse_compressed_tensor(
                                    compressed_indices, plain_indices, values,
                                    layout=layout, requires_grad=requires_grad)
                            else:
                                sparse = torch.sparse_compressed_tensor(
                                    compressed_indices, plain_indices, values, size,
                                    dtype=dtype, layout=layout, device=expected_device, requires_grad=requires_grad)

                        self.assertEqual(layout, sparse.layout)
                        self.assertEqual(size, sparse.shape)
                        self.assertEqual(compressed_indices_expect, compressed_indices_mth(sparse))
                        self.assertEqual(plain_indices_expect, plain_indices_mth(sparse))
                        self.assertEqual(values_expect, sparse.values())
                        self.assertEqual(sparse.device, sparse.values().device)
                        self.assertEqual(sparse.device, expected_device)
                        self.assertEqual(sparse.values().requires_grad, requires_grad)
                        self.assertEqual(sparse.requires_grad, requires_grad)
                        self.assertFalse(compressed_indices_mth(sparse).requires_grad)
                        self.assertFalse(plain_indices_mth(sparse).requires_grad)

    @skipMeta
    @sparse_compressed_nonblock_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_empty(self, layout, device, dtype):
        ns = [5, 2, 0]
        batch_shapes = [(), (2,), (2, 3)]
        compressed_dim = {
            torch.sparse_csr: -2,
            torch.sparse_csc: -1,
        }[layout]
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        for m, n, b in itertools.product(ns, ns, batch_shapes):
            shape = (*b, m, n)
            with torch.sparse.check_sparse_tensor_invariants(enable=False):
                # torch.empty may return invalid sparse compressed tensors
                result = torch.empty(shape, dtype=dtype, device=device, layout=layout)
            self.assertEqual(result.shape, shape)
            self.assertEqual(result.dtype, dtype)
            self.assertEqual(result.device, torch.device(device))
            self.assertEqual(result.layout, layout)
            self.assertEqual(compressed_indices_mth(result).shape, (*b, shape[compressed_dim] + 1,))
            self.assertEqual(plain_indices_mth(result).shape, (*b, 0,))
            self.assertEqual(result.values().shape, (*b, 0,))
            self.assertEqual(result._nnz(), 0)
            self.assertEqual(compressed_indices_mth(result).device, torch.device(device))
            self.assertEqual(plain_indices_mth(result).device, torch.device(device))
            self.assertEqual(result.values().device, torch.device(device))
            self.assertEqual(compressed_indices_mth(result).dtype, torch.int64)
            self.assertEqual(plain_indices_mth(result).dtype, torch.int64)
            self.assertEqual(result.values().dtype, dtype)

    @skipMeta
    @sparse_compressed_nonblock_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_empty_errors(self, layout, device, dtype):
        with self.assertRaisesRegex(RuntimeError,
                                    "torch.empty: Only batched sparse compressed \\(non-block\\) tensors are supported"
                                    ", but got size"):
            torch.empty((5,), dtype=dtype, device=device, layout=layout)

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
    def test_sparse_compressed_tensor_with_dims(self, layout, device, dtype):

        def get_sparse_compressed_tensor_properties(s):
            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices, plain_indices = s.crow_indices(), s.col_indices()
            else:
                compressed_indices, plain_indices = s.ccol_indices(), s.row_indices()
            values = s.values()
            return dict(shape=s.shape, dtype=s.dtype, device=s.device, nnz=s._nnz(), layout=s.layout,
                        compressed_indices_shape=compressed_indices.shape,
                        compressed_indices_dtype=compressed_indices.dtype,
                        compressed_indices_device=compressed_indices.device,
                        plain_indices_shape=plain_indices.shape,
                        plain_indices_dtype=plain_indices.dtype,
                        plain_indices_device=plain_indices.device,
                        values_shape=values.shape,
                        values_dtype=values.dtype,
                        values_device=values.device)

        for index_dtype in [torch.int32, torch.int64]:
            for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
                dense_dim = t.dense_dim()
                sparse_dim = t.sparse_dim()
                batch_dim = t.ndim - sparse_dim - dense_dim
                nnz = t.values().shape[batch_dim]
                if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                    blocksize = t.values().shape[batch_dim + 1: batch_dim + 1 + sparse_dim]
                else:
                    blocksize = ()

                e = torch.ops.aten._sparse_compressed_tensor_with_dims(nnz, dense_dim, t.shape, blocksize, index_dtype,
                                                                       dtype=dtype, layout=layout, device=device)

                e_prop, t_prop = get_sparse_compressed_tensor_properties(e), get_sparse_compressed_tensor_properties(t)
                for k, v in e_prop.items():
                    self.assertEqual(v, t_prop[k], lambda msg: f'{msg} when comparing {k}, expected {t_prop[k]}, got {v}')

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_clone(self, layout, device, dtype):
        for sparse in self.generate_simple_inputs(
                layout, device=device, dtype=dtype, index_dtype=torch.int32):
            cloned_sparse = sparse.clone()
            self.assertEqual(sparse, cloned_sparse)

    @all_sparse_compressed_layouts()
    def test_print(self, layout, device):
        compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
        printed = []
        for enable_hybrid in [False, True]:
            # using local patterns for test_print stability
            patterns = [
                # 2 x 3 batch of 3 x 2 tensors, trivial blocksize, non-hybrid/hybrid:
                ([[[[1, 2, 0],
                    [1, 0, 3]],
                   [[1, 2, 3],
                    [1, 0, 0]],
                   [[1, 0, 0],
                    [1, 2, 3]]],
                  [[[0, 2, 0],
                    [1, 2, 3]],
                   [[1, 0, 3],
                    [1, 2, 0]],
                   [[1, 2, 3],
                    [0, 2, 0]]]], [(2, 1)], [(), (4,)] if enable_hybrid else [()]),
                # tensor with non-trivial blocksize, non-hybrid/hybrid:
                ([[0, 1, 0, 2, 0, 2],
                  [0, 1, 0, 0, 2, 0],
                  [3, 3, 3, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0],
                  [0, 5, 0, 6, 6, 6],
                  [5, 0, 5, 6, 6, 6],
                  [0, 0, 0, 0, 8, 8],
                  [7, 7, 7, 0, 8, 8]], [(2, 3)], [(), (4, 2)] if enable_hybrid else [()]),
            ]
            for index_dtype in [torch.int32, torch.int64]:
                for dtype in [torch.float32, torch.float64]:
                    for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
                            layout, device=device, dtype=dtype, index_dtype=index_dtype, enable_hybrid=enable_hybrid,
                            enable_non_contiguous_indices=False, enable_non_contiguous_values=False,
                            enable_zero_sized=False, output_tensor=False, patterns=patterns):
                        size = tuple(kwargs['size'])
                        block_ndim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
                        base_ndim = 2
                        batch_ndim = compressed_indices.dim() - 1
                        dense_ndim = values.dim() - batch_ndim - block_ndim - 1
                        if enable_hybrid and dense_ndim == 0:
                            # non-hybrid cases are covered by the enable_hybrid==False loop
                            continue
                        batchsize = size[:batch_ndim]
                        basesize = size[batch_ndim:batch_ndim + base_ndim]
                        densesize = size[batch_ndim + base_ndim:]
                        assert len(densesize) == dense_ndim
                        printed.append(f"########## {dtype}/{index_dtype}/size={batchsize}+{basesize}+{densesize} ##########")
                        x = torch.sparse_compressed_tensor(compressed_indices,
                                                           plain_indices,
                                                           values, size, dtype=dtype, layout=layout, device=device)
                        printed.append("# sparse tensor")
                        printed.append(str(x))
                        printed.append(f"# _{compressed_indices_mth.__name__}")
                        printed.append(str(compressed_indices_mth(x)))
                        printed.append(f"# _{plain_indices_mth.__name__}")
                        printed.append(str(plain_indices_mth(x)))
                        printed.append("# _values")
                        printed.append(str(x.values()))
                        printed.append('')
                    printed.append('')
        orig_maxDiff = self.maxDiff
        self.maxDiff = None
        try:
            self.assertExpected('\n'.join(printed))
            self.maxDiff = orig_maxDiff
        except Exception:
            self.maxDiff = orig_maxDiff
            raise

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_copy(self, layout, device, dtype):

        def run_test(shape, blocksize, nnz, index_type):
            a = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            b = self.genSparseCompressedTensor(shape, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)

            a.copy_(b)

            self.assertEqual(a, b)

        ns = [(9, 3), (2, 1), (0, 0)]  # (number of dimensions, the corresponding block size)
        batch_shapes = [(), (2,), (2, 3)]
        for ((m, bm), (n, bn), b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
            blocksize = (bm, bn) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
            run_test((*b, m, n), blocksize, 0, index_dtype)
            run_test((*b, m, n), blocksize, m * n, index_dtype)

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_copy_errors(self, layout, device, dtype):
        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
        nnz = 6 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 1
        shape1 = (2 * 6, 3 * 6) if layout in {torch.sparse_bsr, torch.sparse_bsc} else (2, 3)
        for index_dtype in [torch.int32, torch.int64]:
            a = self.genSparseCompressedTensor(shape1, 0, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)

            with self.assertRaisesRegex(RuntimeError,
                                        "copy of sparse compressed tensors having different layouts is not supported."):
                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))

            b = self.genSparseCompressedTensor(shape1, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            assert a._nnz() != b._nnz(), (a._nnz(), b._nnz())
            with self.assertRaisesRegex(RuntimeError,
                                        "only sparse compressed tensors with the same number of specified elements are supported."):
                a.copy_(b)

            shape2 = tuple(reversed(shape1))
            c = self.genSparseCompressedTensor(shape2, nnz, dtype=dtype, layout=layout, device=device,
                                               index_dtype=index_dtype, blocksize=blocksize)
            with self.assertRaisesRegex(
                    RuntimeError,
                    "expected shapes of self and src to match along dimension"):
                b.copy_(c)

            if blocksize:
                blocksize1 = tuple(reversed(blocksize))
                d = self.genSparseCompressedTensor(shape1, nnz, dtype=dtype, layout=layout, device=device,
                                                   index_dtype=index_dtype, blocksize=blocksize1)
                with self.assertRaisesRegex(RuntimeError,
                                            "copy of sparse compressed tensors having different block sizes is not supported"):
                    b.copy_(d)

    def _smallest_divisor(self, n):
        # Returns the smallest divisor of n greater than 1, or n itself if n is prime.
        for i in range(2, int(n ** 0.5) + 1):
            if n % i == 0:
                return i
        return n

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @all_sparse_compressed_layouts()
    @ops(_sparse_compressed_ops)
    @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2})
    def test_consistency(self, layout, device, dtype, op):
        """Checks that the op on a strided and on a sparse tensor will
        produce the same results.
        """
        if not op.supports_sparse_layout(layout):
            self.skipTest(f"{op.name} does not support input with {layout} layout")

        # FIXME: remove in followup once integer support is landed for segment_reduce
        if (layout == torch.sparse_csr and not dtype.is_floating_point
                and op.name in ('masked.mean', 'masked.amax', 'masked.amin')):
            self.skipTest(f"{op.name} does not support input with {layout} layout and {dtype} dtype")

        require_mask = isinstance(op, ReductionOpInfo) and 'masked.' in op.name

        samples = []
        for sample in op.sample_inputs(device, dtype):
            if sample.input.ndim < 2:
                continue
            dense_dim = sample.input.ndim - 2
            blocksize = (tuple(map(self._smallest_divisor, sample.input.shape[:2]))
                         if layout in {torch.sparse_bsr, torch.sparse_bsc} else None)

            def _to_sparse(x):
                if isinstance(x, torch.Tensor):
                    if blocksize is None:
                        if x.ndim != sample.input.ndim:
                            return x
                    elif x.ndim != sample.input.ndim + 2 or x.shape[-3] % blocksize[0] or x.shape[-2] % blocksize[1]:
                        return x
                    return x.clone().to_sparse(layout=layout, blocksize=blocksize, dense_dim=dense_dim)
                return x

            sparse_sample = sample.transform(_to_sparse)
            # Some strided samples (with inf, nan elements) appear to share
            # storage, so we must clone:
            sample = sample.transform(lambda x: (x.clone() if isinstance(x, torch.Tensor) else x))

            if validate_sample_input_sparse(op, sparse_sample, check_validate=False) is not sparse_sample:
                # that is, the validation returns the sparse sample
                # wrapped within an ErrorInput instance
                continue
            samples.append((sample, sparse_sample))

        # Fail early to prevent silent success with this test
        if len(samples) == 0:
            raise ValueError("Expected at least one 2 or higher D tensor in samples.")

        # Re-define atol and rtol for operations whose result values
        # are random (and hence, non-comparable) but we still want to
        # check the shape, dtype, etc. attributes of the results:
        atol = rtol = None
        if op.name == 'randn_like':
            atol = 1e300
            rtol = 1

        for sample, sparse_sample in samples:
            expected = op(sample.input, *sample.args, **sample.kwargs)
            assert torch.is_tensor(expected)
            output = op(sparse_sample.input, *sparse_sample.args, **sparse_sample.kwargs)
            assert torch.is_tensor(output)
            strided_output = output.to_dense()
            if require_mask and sample.kwargs.get('mask') is not None:
                output_mask = torch.masked._output_mask(op.op, sample.input, *sample.args, **sample.kwargs)
                expected.masked_fill_(~output_mask, 0)
            self.assertEqual(strided_output, expected, atol=atol, rtol=rtol)

    @skipMeta
    @all_sparse_compressed_layouts()
    @all_sparse_compressed_layouts('layout2')
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_empty_like(self, layout, layout2, device, dtype):
        for sparse in self.generate_simple_inputs(layout):
            if layout == layout2:
                result = torch.empty_like(sparse, layout=layout2)
                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[result.layout]
                torch._validate_sparse_compressed_tensor_args(compressed_indices_mth(result),
                                                              plain_indices_mth(result),
                                                              result.values(),
                                                              result.shape,
                                                              result.layout)
                self.assertEqual(sparse.shape, result.shape)
            else:
                self.assertRaisesRegex(
                    RuntimeError,
                    "empty_like with different sparse layout is not supported",
                    lambda: torch.empty_like(sparse, layout=layout2)
                )

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_validate(self, layout, device, dtype):
        def make_zero_batched(t):
            return torch.empty(*((0,) + t.shape), dtype=t.dtype, device=t.device)

        for index_dtype in [torch.int32, torch.int64]:
            for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
                    layout, device=device, dtype=dtype, index_dtype=index_dtype, output_tensor=False):
                size = kwargs['size']
                torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout)

                # check empty batch
                torch._validate_sparse_compressed_tensor_args(
                    *(make_zero_batched(t) for t in (compressed_indices, plain_indices, values)),
                    (0,) + size,
                    layout
                )

            compressed_indices = torch.tensor([0, 0], dtype=index_dtype)
            plain_indices = torch.tensor([], dtype=index_dtype)
            torch._validate_compressed_sparse_indices(layout in {torch.sparse_csr, torch.sparse_bsr},
                                                      compressed_indices, plain_indices, 1, 1, 0)

    def _generate_invalid_input(self, layout, device):
        from functools import partial

        def shape(shape, basedim=0):
            blocksize = (1, 1)
            if layout is torch.sparse_csc:
                shape = shape[:basedim] + (shape[basedim + 1], shape[basedim]) + shape[basedim + 2:]
            elif layout is torch.sparse_bsc:
                shape = shape[:basedim] + (shape[basedim + 1] * blocksize[1], shape[basedim] * blocksize[0]) + shape[basedim + 2:]
            elif layout is torch.sparse_bsr:
                shape = shape[:basedim] + (shape[basedim] * blocksize[0], shape[basedim + 1] * blocksize[1]) + shape[basedim + 2:]
            return shape

        def values(lst, device=device):
            if layout in {torch.sparse_bsr, torch.sparse_bsc}:
                lst = [[[item]] for item in lst]
            return torch.tensor(lst, device=device)

        tensor = partial(torch.tensor, device=device)
        values = partial(values, device=device)

        yield ('incontiguous compressed_indices',
               tensor([0, -1, 2, -1, 4, -1])[::2],
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'expected compressed_indices to be a contiguous tensor per batch')

        yield ('incontiguous plain_indices',
               tensor([0, 2, 4]),
               tensor([0, -1, 1, -1, 0, -1, 2, -1])[::2],
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'expected plain_indices to be a contiguous tensor per batch')

        yield ('0-D compressed_indices',
               tensor(0),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'compressed_indices must have dimensionality >= 1 but got 0')

        yield ('compressed/plain_indices mismatch of dimensionalities',
               tensor([[0, 2, 4]]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively')

        if layout in {torch.sparse_csr, torch.sparse_csc}:
            yield ('indices and values mismatch of dimensionalities',
                   tensor([[0, 2, 4]]),
                   tensor([[0, 1, 0, 2]]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1')
        else:
            yield ('indices and values mismatch of dimensionalities',
                   tensor([[0, 2, 4]]),
                   tensor([[0, 1, 0, 2]]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 2\) but got 3')

        yield ('invalid size',
               tensor([0, 2, 4]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               (2,),
               r'tensor dimensionality must be sum of batch, base, and dense dimensionalities \(=0 \+ 2 \+ 0\) but got 1')

        yield ('invalid batchsize',
               tensor([[0, 2, 4]]),
               tensor([[0, 1, 0, 2]]),
               values([[1, 2, 3, 4]]),
               shape((2, 2, 3), 1),
               r'all batch dimensions of compressed_indices \(=\[1\]\), plain_indices \(=\[1\]\), '
               r'and values \(=\[1\]\) must be equal to tensor batch dimensions \(=\[2\]\)')

        if layout is torch.sparse_bsr:
            yield ('invalid blocksize',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
                   shape((2, 3)),
                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')

        if layout is torch.sparse_bsc:
            yield ('invalid blocksize',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   tensor([[[1, 11]], [[2, 22]], [[3, 33]], [[4, 33]]]),
                   shape((3, 2)),
                   r'tensor shape\[1\] \(=3\) must be divisible with blocksize\[1\] \(=2\) as defined by values shape')

        yield ('invalid compressed_indices shape',
               tensor([0, 2, 3, 4]),
               tensor([0, 1, 0, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices.shape\[-1\] must be equal to the number of compressed_indices_names \+ 1 \(=3\), but got 4')

        yield ('invalid compressed_indices shape',
               tensor([0, 2, 4]),
               tensor([0, 1, 0, 1, 2]),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'plain_indices.shape\[-1\] must be equal to nnz \(=4\) as defined by values.shape\[0\], but got 5')

        yield ('compressed/plain_indices mismatch of dtype',
               tensor([0, 2, 4], dtype=torch.int32),
               tensor([0, 1, 0, 2], dtype=torch.int64),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices and plain_indices must have the same dtype, bot got Int and Long, respectively')

        yield ('invalid compressed/plain_indices dtype',
               tensor([0, 2, 4], dtype=torch.int16),
               tensor([0, 1, 0, 2], dtype=torch.int16),
               values([1, 2, 3, 4]),
               shape((2, 3)),
               r'compressed_indices and plain_indices dtype must be Int or Long, but got Short')

        # CUDA kernel asserts are not recoverable, so we skip these for now
        if torch.device(device).type == 'cpu':
            yield ('invalid compressed_indices[0]',
                   tensor([1, 2, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')
            yield ('invalid compressed_indices[0] when nnz == 0',
                   tensor([1, 0], dtype=torch.int64),
                   tensor([], dtype=torch.int64),
                   values([1])[:0],
                   shape((1, 1)),
                   r'`compressed_indices\[..., 0\] == 0` is not satisfied.')

            yield ('invalid compressed_indices[-1]',
                   tensor([0, 2, 5]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')

            yield ('invalid compressed_indices[-1] when nnz == 0',
                   tensor([0, 1], dtype=torch.int64),
                   tensor([], dtype=torch.int64),
                   values([1])[:0],
                   shape((1, 1)),
                   r'`compressed_indices\[..., -1\] == nnz` is not satisfied.')

            yield ('invalid compressed_indices.diff(dim=-1)',
                   tensor([0, 0, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')

            yield ('invalid compressed_indices.diff(dim=-1)',
                   tensor([0, 5, 4]),
                   tensor([0, 1, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'0 <= compressed_indices\[..., 1:\] - compressed_indices\[..., :\-1\] <= plain_dim` is not satisfied.')

            yield ('invalid min(plain_indices)',
                   tensor([0, 2, 4]),
                   tensor([0, -1, 0, 3]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`0 <= plain_indices < plain_dim` is not satisfied.')

            yield ('invalid max(plain_indices)',
                   tensor([0, 2, 4]),
                   tensor([0, 1, 0, 3]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`0 <= plain_indices < plain_dim` is not satisfied.')

            yield ('non-coalesced',
                   tensor([0, 2, 4]),
                   tensor([1, 0, 0, 2]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'`plain_indices\[..., compressed_indices\[..., i - 1\]:compressed_indices\[..., i\]\] '
                   'for all i = 1, ..., compressed_dim '
                   'are sorted and distinct along the last dimension values` is not satisfied.')

        if TEST_CUDA and torch.device(device).type == 'cpu':
            yield ('indices and values mismatch of device',
                   torch.tensor([0, 2, 4]),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4], device='cuda'),
                   shape((2, 3)),
                   r'device of compressed_indices \(=cpu\) must match device of values \(=cuda:0\)')
            yield ('compressed_indices and values mismatch of device',
                   torch.tensor([0, 2, 4], device='cuda'),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4]),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')
            yield ('compressed/plain_indices mismatch of device',
                   torch.tensor([0, 2, 4], device='cuda'),
                   torch.tensor([0, 1, 0, 1]),
                   values([1, 2, 3, 4], device='cuda'),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!')

        if TEST_CUDA and torch.device(device).type == 'cuda' and torch.cuda.device_count() >= 2:
            yield ('indices and values mismatch of device index',
                   torch.tensor([0, 2, 4], device='cuda:0'),
                   torch.tensor([0, 1, 0, 1], device='cuda:0'),
                   values([1, 2, 3, 4], device='cuda:1'),
                   shape((2, 3)),
                   r'device of compressed_indices \(=cuda:0\) must match device of values \(=cuda:1\)')
            yield ('compressed_indices and values mismatch of device index',
                   torch.tensor([0, 2, 4], device='cuda:0'),
                   torch.tensor([0, 1, 0, 1], device='cuda:1'),
                   values([1, 2, 3, 4], device='cuda:0'),
                   shape((2, 3)),
                   r'Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!')
    @skipMeta
    @all_sparse_compressed_layouts()
    @parametrize('target', [subtest('validate_sparse_compressed_tensor_args'),
                            subtest('sparse_compressed_tensor'),
                            subtest('sparse_compressed_tensor_no_size')])
    def test_invalid_input(self, layout, device, target):
        for label, compressed_indices, plain_indices, values, size, errmsg in self._generate_invalid_input(layout, device):
            if layout is torch.sparse_bsr:
                errmsg = errmsg.replace('compressed_indices_name', 'row block').replace('plain_indices_name', 'column block')
            elif layout is torch.sparse_bsc:
                errmsg = errmsg.replace('compressed_indices_name', 'column block').replace('plain_indices_name', 'row block')
            elif layout is torch.sparse_csr:
                errmsg = errmsg.replace('compressed_indices_name', 'row').replace('plain_indices_name', 'column')
            elif layout is torch.sparse_csc:
                errmsg = errmsg.replace('compressed_indices_name', 'column').replace('plain_indices_name', 'row')
            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                errmsg = errmsg.replace('compressed_indices', 'crow_indices') \
                               .replace('plain_indices', 'col_indices') \
                               .replace('plain_dim', 'ncols') \
                               .replace('compressed_dim', 'nrows')
            else:
                errmsg = errmsg.replace('compressed_indices', 'ccol_indices') \
                               .replace('plain_indices', 'row_indices') \
                               .replace('plain_dim', 'nrows') \
                               .replace('compressed_dim', 'ncols')

            if target == 'sparse_compressed_tensor_no_size' and label in {
                    'invalid size', 'invalid batchsize', 'invalid compressed_indices shape', 'invalid max(plain_indices)',
                    'invalid blocksize'}:
                # Skip invalid size input as a valid size is estimated for other inputs
                continue

            with self.assertRaisesRegex(RuntimeError, errmsg):
                if target == 'validate_sparse_compressed_tensor_args':
                    torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, values, size, layout)
                elif target == 'sparse_compressed_tensor':
                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout)
                elif target == 'sparse_compressed_tensor_no_size':
                    torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=layout)
                else:
                    raise NotImplementedError(target)

    @skipMeta
    @onlyCPU
    @largeTensorTest("30GB", "cpu")
    def test_invalid_input_csr_large(self):
        rows = 2 ** 31
        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in row dimension'):
            torch.sparse_csr_tensor(torch.arange(rows + 1, dtype=torch.int32) // rows,
                                    torch.tensor([0], dtype=torch.int32),
                                    torch.tensor([1]), (rows, 1))
        torch.sparse_csr_tensor(torch.arange(rows + 1, dtype=torch.int64) // rows,
                                torch.tensor([0], dtype=torch.int64),
                                torch.tensor([1]), (rows, 1))

        cols = 2 ** 31
        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in column dimension'):
            torch.sparse_csr_tensor(torch.arange(2, dtype=torch.int32),
                                    torch.tensor([0], dtype=torch.int32),
                                    torch.tensor([1]), (1, cols))
        torch.sparse_csr_tensor(torch.arange(2, dtype=torch.int64),
                                torch.tensor([0], dtype=torch.int64),
                                torch.tensor([1]), (1, cols))

        nnz = 2 ** 31
        with self.assertRaisesRegex(RuntimeError, '32-bit integer overflow in nnz'):
            # nnz cannot be stored in int32 crow_indices,
            # but the `crow_indices[..., -1] == nnz` check happens after the overflow validation.
            # So we can use `nnz - 1` here to avoid `value cannot be converted to type int32 without overflow`
            # during construction of crow_indices.
            torch.sparse_csr_tensor(torch.tensor([0, nnz // 2, nnz - 1], dtype=torch.int32),
                                    torch.arange(nnz // 2, dtype=torch.int32).repeat(2),
                                    torch.ones(nnz, dtype=torch.int8), (2, nnz // 2))
        torch.sparse_csr_tensor(torch.tensor([0, nnz // 2, nnz], dtype=torch.int64),
                                torch.arange(nnz // 2, dtype=torch.int64).repeat(2),
                                torch.ones(nnz, dtype=torch.int8), (2, nnz // 2))

    @skipMeta
    @onlyCPU
    @all_sparse_compressed_layouts()
    def test_dim(self, layout):
        for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(layout, output_tensor=False):
            size = kwargs['size']
            batch_dim = compressed_indices.dim() - 1
            sparse_dim = 2
            block_dim = 2 if layout in {torch.sparse_bsr, torch.sparse_bsc} else 0
            dense_dim = values.dim() - batch_dim - block_dim - 1
            sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, layout=layout)
            self.assertEqual(sparse.sparse_dim(), sparse_dim)
            self.assertEqual(sparse.dense_dim(), dense_dim)


    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
    def test_to_dtype(self, layout, device, dtype):
        # to_dense does not support hybrid inputs
        for sparse in self.generate_simple_inputs(layout, dtype=dtype, device=device, enable_hybrid=False):
            for to_dtype in all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16):
                sparse_to_dtype = sparse.to(to_dtype)
                dense_to_dtype = sparse.to_dense().to(to_dtype)
                self.assertEqual(sparse_to_dtype.to_dense(), dense_to_dtype)

    @skipMeta
    @all_sparse_compressed_layouts()
    @dtypes(torch.double)
    def test_pickle(self, layout, dtype, device):
        import pickle

        for sparse in self.generate_simple_inputs(layout, device=device, dtype=dtype):
            serialized = pickle.dumps(sparse)
            sparse_loaded = pickle.loads(serialized)

            self.assertEqual(sparse, sparse_loaded)

    @all_sparse_compressed_layouts()
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_select_copy(self, device, dtype, index_dtype, layout):

        def is_view_of(base, other):
            # a shameless copy of TestViewOps.is_view_of
            if (
                not other._is_view() or
                other is base or
                other._base is not base or
                base.device != other.device
            ):
                return False
            if base.device.type in ('cpu', 'cuda'):
                if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr():
                    return False
            return True

        kwargs = dict(device=device, dtype=dtype, index_dtype=index_dtype)
        for sparse, dense in zip(self.generate_simple_inputs(layout, **kwargs),
                                 self.generate_simple_inputs(torch.strided, **kwargs)):
            if layout in {torch.sparse_csr, torch.sparse_bsr}:
                n_batchdim = sparse.crow_indices().ndim - 1
            elif layout in {torch.sparse_csc, torch.sparse_bsc}:
                n_batchdim = sparse.ccol_indices().ndim - 1
            else:
                assert 0  # unreachable
            self.assertEqual(sparse, dense)
            for dim in range(sparse.ndim):
                if sparse.shape[dim] == 0:
                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
                        torch.select_copy(sparse, dim, 0)
                    with self.assertRaisesRegex(IndexError, "index 0 out of range for tensor of size"):
                        torch.select_copy(dense, dim, 0)
                elif n_batchdim and dim >= n_batchdim and dim < n_batchdim + 2:
                    with self.assertRaisesRegex(
                            RuntimeError,
                            "selecting sparse dimensions is not supported for batched sparse compressed tensors"):
                        torch.select_copy(sparse, dim, 0)
                else:
                    for index in {0, sparse.shape[dim] // 2, sparse.shape[dim] - 1}:
                        dense_select = torch.select_copy(dense, dim, index)
                        sparse_select = torch.select_copy(sparse, dim, index)
                        self.assertEqual(sparse_select, dense_select)
                        self.assertFalse(is_view_of(sparse_select.values(), sparse.values()))


def _npref_block_addmm_addmv(c, a, b, alpha, beta):
    # Reference implementation used to check block addmm/addmv results.
    return alpha * (a @ b) + beta * c


class TestSparseCSR(TestCase):

    def test_csr_stride(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)

        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have strides"):
            a.stride()

        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have strides"):
            a.stride(-1)

    def test_csr_storage(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)

        with self.assertRaisesRegex(RuntimeError, "Cannot access storage of SparseCsrTensorImpl"):
            a.storage()

    def test_csr_is_contiguous(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)

        with self.assertRaisesRegex(RuntimeError, "Sparse CSR tensors do not have is_contiguous"):
            a.is_contiguous()

    @onlyCPU
    @largeTensorTest("20GB", "cpu")
    def test_csr_nnz(self):
        # Tests the limits of the number of specified elements in CSR tensors, see gh-102520.
        for nnz in [0, 2**31]:
            rows, cols = 1, max(nnz, 1)
            crow_indices = torch.tensor([0, nnz], dtype=torch.int64)
            col_indices = torch.arange(nnz, dtype=torch.int64)
            values = torch.ones(nnz, dtype=torch.int8)
            a = torch.sparse_csr_tensor(crow_indices, col_indices, values, (rows, cols))
            self.assertEqual(a._nnz(), nnz)

    def test_csr_double_to_sparse_csr(self):
        a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64)
        a.to_sparse_csr().to_sparse_csr()

    @all_sparse_compressed_layouts()
    @parametrize("index_dtype", [torch.int32, torch.int64])
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
    def test_select(self, device, dtype, index_dtype, layout):
        compressed_indices_mth = {
            torch.sparse_csr: torch.Tensor.crow_indices,
            torch.sparse_bsr: torch.Tensor.crow_indices,
            torch.sparse_csc: torch.Tensor.ccol_indices,
            torch.sparse_bsc: torch.Tensor.ccol_indices,
        }[layout]

        plain_indices_mth = {
            torch.sparse_csr: torch.Tensor.col_indices,
            torch.sparse_bsr: torch.Tensor.col_indices,
            torch.sparse_csc: torch.Tensor.row_indices,
            torch.sparse_bsc: torch.Tensor.row_indices,
        }[layout]
        create_tensor_mth = {
            torch.sparse_csr: torch.sparse_csr_tensor,
            torch.sparse_bsr: torch.sparse_bsr_tensor,
            torch.sparse_csc: torch.sparse_csc_tensor,
            torch.sparse_bsc: torch.sparse_bsc_tensor,
        }[layout]

        shape = (2, 3, 6, 10)
        nnz = 6
        blocksize = (2, 2) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
        sparse = self.genSparseCompressedTensor(
            shape, nnz, device=device, layout=layout, dtype=dtype, index_dtype=index_dtype, blocksize=blocksize)
        comp_indices = compressed_indices_mth(sparse)
        plain_indices = plain_indices_mth(sparse)
        values = sparse.values()

        # select from batch dimensions
        sparse_selected12 = sparse.select(1, 2)
        expected_sparse_selected12 = create_tensor_mth(comp_indices.select(1, 2).contiguous(),
                                                       plain_indices.select(1, 2).contiguous(),
                                                       values.select(1, 2).contiguous(),
                                                       size=(2, 6, 10),
                                                       dtype=dtype,
                                                       device=device)
        self.assertEqual(expected_sparse_selected12, sparse_selected12)

        # selecting rows/cols with batch dims is not allowed
        sparse_non_batched = sparse[0, 0]
        # select from sparse dimensions
        for select_args in [(0, 0), (1, 1)]:
            sparse_selected = sparse_non_batched.select(*select_args)
            dense_selected = sparse_non_batched.to_dense().select(*select_args)
            self.assertEqual(dense_selected, sparse_selected)

        self.assertEqual(sparse[0, 0, 0, 0], sparse.to_dense()[0, 0, 0, 0])
        # assigning to sparse through indexing is disabled
        with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
            sparse[0, 0, 0, 0] = 99.0

        # select from sparse dimensions without removing batch dims
        msg = "selecting sparse dimensions is not supported for batched sparse compressed tensors."
        with self.assertRaisesRegex(RuntimeError, msg):
            sparse.select(-2, 0)

        with self.assertRaisesRegex(RuntimeError, msg):
            sparse.select(-1, 0)

    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_resize(self, device, dtype):

        def numel(tensor):
            r = 1
            for s in tensor.shape:
                r *= s
            return r

        batch_shapes = [(), (2,), (2, 3)]
        for index_dtype, b in zip([torch.int32, torch.int64], batch_shapes):
            shape = (*b, 2, 3)
            nnz = 6
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            self.assertEqual(a.numel(), numel(a))

            new_shape = (*b, 4, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to larger shape doesn't add specified elements
            self.assertEqual(a._nnz(), nnz)
            self.assertEqual(a.numel(), numel(a))

            new_shape = (*b, 1, 5)
            a.resize_(new_shape)

            self.assertEqual(a.shape, new_shape)
            # resize to smaller shape trims specified elements
            self.assertEqual(a._nnz(), 5)
            self.assertEqual(a.numel(), numel(a))

            # trim batched dimensions
            a.resize_(new_shape[-2], new_shape[-1])
            self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
            self.assertEqual(a._nnz(), 5)
            self.assertEqual(a.numel(), numel(a))

    @skipMeta
    @dtypes(torch.float, torch.bool)
    @all_sparse_compressed_layouts()
    def test_resize_as_sparse_compressed(self, device, dtype, layout):

        def _check_resize_b_as_a(b, a):
            br = b.clone()
            br.resize_as_sparse_(a)

            # shape is inherited from a
            self.assertEqual(a.shape, br.shape)
            # other metadata is not affected
            self.assertEqual(b.layout, br.layout)
            self.assertEqual(b.device, br.device)
            self.assertEqual(b.dtype, br.dtype)

            def _get_compressed_plain_inds(t):
                compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[t.layout]
                return compressed_indices_mth(t), plain_indices_mth(t)

            br_compressed_indices, br_plain_indices = _get_compressed_plain_inds(br)
            br_values = br.values()

            b_compressed_indices, b_plain_indices = _get_compressed_plain_inds(b)
            a_compressed_indices, a_plain_indices = _get_compressed_plain_inds(a)
            self.assertEqual(a_plain_indices.shape, br_plain_indices.shape)
            self.assertEqual(a_compressed_indices.shape, br_compressed_indices.shape)
            # We don't check the content of br_plain_indices and br_compressed_indices
            # because it is not well-defined (the content depends on the original
            # shape of `b` that `resize_as` ought to discard) nor needed (the
            # subsequent operation likely updates the indices and values of `b` anyway).
            # the device/dtype of indices should always be unaffected
            self.assertEqual(b_plain_indices.dtype, br_plain_indices.dtype)
            self.assertEqual(b_plain_indices.device, br_plain_indices.device)
            self.assertEqual(b_compressed_indices.dtype, br_compressed_indices.dtype)
            self.assertEqual(b_compressed_indices.device, br_compressed_indices.device)
            # values are generated empty, shape is updated
            self.assertEqual(a.values().shape, br_values.shape)
            # the device/dtype of values should always be unaffected
            b_values = b.values()
            self.assertEqual(b_values.dtype, br_values.dtype)
            self.assertEqual(b_values.device, br_values.device)
            # nnz will be picked up from a via the new shape of values
            self.assertEqual(a._nnz(), br._nnz())

            # post resize the invariants of the layout are respected
            torch._validate_sparse_compressed_tensor_args(br_compressed_indices, br_plain_indices, br_values, br.shape,
                                                          br.layout)

        block_sparse = layout in (torch.sparse_bsr, torch.sparse_bsc)
        shape = (2, 1, 6, 4)
        nnz = 4
        blocksize = (2, 1) if block_sparse else ()
        for index_dtype in [torch.int32, torch.int64]:
            a = self.genSparseCompressedTensor(shape,
                                               layout=layout,
                                               device=device,
                                               index_dtype=index_dtype,
                                               dtype=dtype,
                                               nnz=nnz,
                                               blocksize=blocksize)

            # same size, resize should not trigger
            b = self.genSparseCompressedTensor(shape,
                                               layout=layout,
                                               device=device,
                                               index_dtype=index_dtype,
                                               dtype=dtype,
                                               nnz=nnz,
                                               blocksize=blocksize)

            # This test will not always trigger a resize; if the layouts are the same,
            # nothing should happen to b. The invariants of the function as checked should still hold.
            _check_resize_b_as_a(b, a)

            # same ndim, but bigger, more nnz, different dtype, different blocksize if blocked
            b = self.genSparseCompressedTensor(tuple(s * 2 for s in shape),
                                               layout=layout,
                                               device=device,
                                               dtype=torch.chalf,
                                               index_dtype=torch.int64 if index_dtype == torch.int32 else torch.int32,
                                               nnz=nnz * 2,
                                               blocksize=tuple(2 * bi for bi in blocksize))
            _check_resize_b_as_a(b, a)

            # different device; only check on the cuda pass as we know we are testing in an
            # environment that has multiple devices

            # TODO: .cpu() does not seem to work correctly for sparse. Causes a call to `copy_` which
            # complains about incompatible nnz between src and self?
            if torch.device(device).type == 'cuda' and (layout not in (torch.sparse_bsc, torch.sparse_bsr)):
                a_cpu = self.genSparseCompressedTensor(shape,
                                                       layout=layout,
                                                       device='cpu',
                                                       index_dtype=index_dtype,
                                                       dtype=dtype,
                                                       nnz=nnz,
                                                       blocksize=blocksize)
                _check_resize_b_as_a(b, a)

            # error on a strided
            a_strided = a.to_dense()
            with self.assertRaisesRegex(
                    RuntimeError, r'resize_as_sparse_compressed_: src expected sparse compressed tensor layout'):
                b.resize_as_sparse_(a_strided)

            # error on b strided
            b_strided = b.to_dense()
            with self.assertRaisesRegex(
                    RuntimeError, r'resize_as_sparse_compressed_: self expected sparse compressed tensor layout'):
                b_strided.resize_as_sparse_(a)

            # error if layout does not match, transpose induces layout flip
            with self.assertRaisesRegex(RuntimeError,
                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
                b.transpose(-2, -1).resize_as_sparse_(a)
            with self.assertRaisesRegex(RuntimeError,
                                        r"resize_as_sparse_compressed_tensor_: self and src must have the same layout"):
                b.resize_as_sparse_(a.transpose(-2, -1))

    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_resize_errors(self, device, dtype):
        for index_dtype in [torch.int32, torch.int64]:
            shape = (2, 3)
            nnz = 6
            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)

            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
                new_shape = (4,)
                a.resize_(new_shape)

            # resizing of columns to smaller size is not implemented
            with self.assertRaisesRegex(
                RuntimeError,
                "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported.",
            ):
                new_shape = (2, 2)
                a.resize_(new_shape)

    @skipIfTorchDynamo()
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csr_from_dense(self, device, dtype):
        dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([4, 5, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([2, 0], dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([1, 1], dtype=dtype), sparse.values())

        dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]], dtype=dtype, device=device)
        sparse = dense.to_sparse_csr()
        self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64), sparse.crow_indices())
        self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
        self.assertEqual(torch.tensor([2] * 9, dtype=dtype), sparse.values())

    def _test_sparse_compressed_to_dense(self, device, dtype, layout):
        compressed_format_str = str(layout)[-3:]

        def to_compressed(t):
            return getattr(t, f"to_sparse_{compressed_format_str}")()

        def compressed_constructor(*input, **kwargs):
            constructor = getattr(torch, f"sparse_{compressed_format_str}_tensor")
            return constructor(*input, **kwargs)

        def get_dense_shape(shape, batch_ndim):
            if layout is torch.sparse_csc:
                compressed_dims_slice = slice(batch_ndim + 1, batch_ndim - 1, -1)
            else:
                compressed_dims_slice = slice(batch_ndim, batch_ndim + 2)
            return shape[:batch_ndim] + shape[compressed_dims_slice] + shape[batch_ndim + 2:]

        def transpose(t, batch_ndim):
            if layout is torch.sparse_csc:
                return t.transpose(batch_ndim, batch_ndim + 1)
            return t

        mn = [5, 2, 0]
        for (m, n) in itertools.product(mn, mn):
            size = (m, n)
            dense = make_tensor(size, dtype=dtype, device=device)
            sparse = to_compressed(dense)
            self.assertEqual(sparse.to_dense(), dense)

        batch_shape = (2, 3)
        compressed_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
        plain_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
        values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
        sparse = compressed_constructor(compressed_indices, plain_indices, values, dtype=dtype, device=device)
        dense_shape = get_dense_shape(sparse.shape, len(batch_shape))
        dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(dense_shape)
        self.assertEqual(sparse.to_dense(), transpose(dense, len(batch_shape)))

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csr_to_dense(self, device, dtype):
        self._test_sparse_compressed_to_dense(device, dtype, torch.sparse_csr)

    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_sparse_csc_to_dense(self, device, dtype):
        self._test_sparse_compressed_to_dense(device, dtype, torch.sparse_csc)

    @skipMeta
    @skipCPUIfNoMklSparse
    @coalescedonoff
    @dtypes(torch.double)
    def test_coo_to_csr_convert(self, device, dtype, coalesced):
        with self.assertRaisesRegex(RuntimeError, "Input is supposed to be a vector"):
            torch._convert_indices_from_coo_to_csr(
                torch.randint(100, (5, 5), device=device),
                size=100)

        size = (5, 5)
        sparse_dim = 2
        nnz = 10
        sparse_coo, _, _ = self.genSparseTensor(size, sparse_dim, nnz, coalesced, device, dtype)
        sparse_csr = sparse_coo.to_sparse_csr()

        self.assertTrue(sparse_csr.is_sparse_csr)
        self.assertEqual(sparse_csr.to_dense(), sparse_coo.to_dense())

        vec = torch.randn((5, 1), dtype=dtype, device=device)
        coo_product = sparse_coo.matmul(vec)
        csr_product = sparse_csr.matmul(vec)

        self.assertEqual(coo_product, csr_product)

        vec = torch.randn((100, 1), dtype=dtype, device=device)
        index = torch.tensor([
            [1, 0, 35, 14, 39, 6, 71, 66, 40, 27],
            [92, 31, 62, 50, 22, 65, 89, 74, 56, 34],
        ], dtype=torch.int32)
        values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
        coo = torch.sparse_coo_tensor(index, values, torch.Size([100, 100]), dtype=dtype, device=device)
        csr = coo.to_sparse_csr()

        self.assertEqual(coo.matmul(vec), csr.matmul(vec))

        col_indices = torch.tensor([
            31, 92, 65, 50, 34, 62, 22, 56, 74, 89
        ], dtype=torch.int64, device=device)
        self.assertEqual(csr.col_indices(), col_indices)
1427 values = torch.tensor([2, 1, 6, 4, 10, 3, 5, 9, 8, 7], dtype=dtype, device=device) 1428 self.assertEqual(csr.values(), values) 1429 1430 @parametrize("blocksize", [2, 4]) 1431 @dtypes((torch.double, torch.int32), (torch.double, torch.int64)) 1432 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1433 @skipMeta 1434 def test_csr_to_block_csr(self, device, dtypes, blocksize): 1435 for shape in [(24, 24), (12, 24)]: 1436 dtype, index_dtype = dtypes 1437 m, k = shape 1438 nnz = random.randint(0, m * k) 1439 t = self.genSparseCSRTensor((m * blocksize, k * blocksize), nnz, dtype=dtype, 1440 device=device, index_dtype=index_dtype) 1441 st = sp.csr_matrix((t.values().cpu(), t.col_indices().cpu(), t.crow_indices().cpu()), shape=tuple(t.size())) 1442 block_t = t.to_sparse_bsr((blocksize, blocksize)) 1443 self.assertEqual(block_t.values().dim(), 3) 1444 self.assertTrue(block_t.layout == torch.sparse_bsr) 1445 block_st = st.tobsr(blocksize=(blocksize, blocksize)) 1446 block_st.sort_indices() 1447 self.assertEqual(block_t.values().cpu(), block_st.data) 1448 self.assertEqual(block_t.col_indices().cpu(), torch.tensor(block_st.indices).to(index_dtype)) 1449 self.assertEqual(block_t.crow_indices().cpu(), torch.tensor(block_st.indptr).to(index_dtype)) 1450 1451 @dtypes(torch.double) 1452 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1453 def test_csr_to_block_csr_errors(self, device, dtype): 1454 for index_dtype in [torch.int32, torch.int64]: 1455 nnz = 15 1456 t = self.genSparseCSRTensor((16, 16), nnz, dtype=dtype, 1457 device=device, index_dtype=index_dtype) 1458 1459 with self.assertRaisesRegex(RuntimeError, 1460 r"tensor sparse size \(.*,.*\) must be divisible by given blocksize \(.*,.*\)"): 1461 block_t = t.to_sparse_bsr((5, 5)) 1462 1463 # TODO: Support auto generation of device check for sparse tensors 1464 # See: https://github.com/pytorch/pytorch/issues/59058 1465 @onlyCUDA 1466 @dtypes(torch.double) 1467 def test_matmul_device_mismatch(self, device, dtype): 1468 cpu = torch.rand((10, 10)) 1469 cuda = cpu.cuda() 1470 for s, m1, m2 in itertools.product((cpu, cuda), repeat=3): 1471 csr = m1.to_sparse() 1472 if s.device == csr.device == m2.device: 1473 torch.addmm(s, csr, m2) 1474 else: 1475 with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"): 1476 torch.addmm(s, csr, m2) 1477 1478 @skipCPUIfNoMklSparse 1479 @skipCUDAIfNoSparseGeneric 1480 @dtypes(*floating_and_complex_types()) 1481 @dtypesIfCUDA(*floating_and_complex_types_and( 1482 *[torch.half] if SM53OrLater else [], 1483 *[torch.bfloat16] if SM80OrLater else [])) 1484 def test_csr_matvec(self, device, dtype): 1485 1486 if TEST_WITH_ROCM and (dtype == torch.half or dtype == torch.bfloat16): 1487 self.skipTest("ROCm doesn't work with half dtypes correctly.") 1488 1489 side = 100 1490 for index_dtype in [torch.int32, torch.int64]: 1491 csr = self.genSparseCSRTensor((side, side), 1000, device=device, dtype=dtype, index_dtype=index_dtype) 1492 vec = torch.randn(side, dtype=dtype, device=device) 1493 1494 res = csr.matmul(vec) 1495 expected = csr.to_dense().matmul(vec) 1496 1497 self.assertEqual(res, expected) 1498 1499 bad_vec = torch.randn(side + 10, dtype=dtype, device=device) 1500 err_msg = "size mismatch, got" 1501 with self.assertRaisesRegex(RuntimeError, err_msg): 1502 csr.matmul(bad_vec) 1503 1504 @onlyCUDA 1505 # hmm, the test passes ok on CUDA when Rocm is not available: 1506 @skipCUDAIfRocmVersionLessThan((5, 2)) 1507 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 
1508 def test_baddbmm(self, device, dtype): 1509 1510 # TODO: disable the invariant checks within torch.baddbmm that 1511 # constructs unconventional csr tensors leading to 1512 # RuntimeError: tensor dimensionality must be sum of batch, 1513 # base, and dense dimensionalities (=0 + 2 + 0) but got 3 1514 # when invariant checking is enabled. When done, undecorate run_test. 1515 @torch.sparse.check_sparse_tensor_invariants(enable=False) 1516 def run_test(c, a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None): 1517 alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random() 1518 beta = complex(random.random(), random.random()) if dtype.is_complex else random.random() 1519 b = b.mH if (op_b and a.shape == b.shape) else b 1520 1521 actual = torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta) 1522 1523 out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c) 1524 torch.baddbmm(c, a_batched, b, alpha=alpha, beta=beta, out=out) 1525 1526 expected = [torch.addmm(c[i], a, b[i], alpha=alpha, beta=beta) for i in range(c.shape[0])] 1527 expected = torch.stack(expected, 0) 1528 1529 self.assertEqual(actual, out) 1530 self.assertEqual(actual, expected) 1531 1532 for index_dtype in [torch.int32, torch.int64]: 1533 for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]): 1534 nnz = random.randint(0, m * k) 1535 a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) 1536 1537 # a_batched is a regular CSR tensor but with a batch dimension in the shape 1538 a_batched = torch.sparse_csr_tensor( 1539 a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False) 1540 1541 b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous) 1542 c = make_tensor((batch_size, m, n), dtype=dtype, device=device, noncontiguous=noncontiguous) 1543 for op_b, op_out in itertools.product([True, False], repeat=2): 1544 run_test(c, a, a_batched, b, op_b, op_out, dtype=dtype, device=device) 1545 1546 @onlyCUDA 1547 @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported") 1548 @skipCUDAIfNoSparseGeneric 1549 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 1550 def test_bmm(self, device, dtype): 1551 def run_test(a, a_batched, b, op_b=False, op_out=False, *, dtype=None, device=None): 1552 b = b.mH if (op_b and a.shape == b.shape) else b 1553 1554 actual = torch.bmm(a_batched, b) 1555 1556 out = torch.empty_like(actual.mH if op_out and a.shape == b.shape else actual) 1557 torch.bmm(a_batched, b, out=out) 1558 1559 expected = [torch.mm(a, b[i]) for i in range(b.shape[0])] 1560 expected = torch.stack(expected, 0) 1561 1562 self.assertEqual(actual, out) 1563 self.assertEqual(actual, expected) 1564 1565 for index_dtype in [torch.int32, torch.int64]: 1566 for (m, n, k), batch_size, noncontiguous in zip(itertools.product([2, 5], repeat=3), [1, 3], [True, False]): 1567 nnz = random.randint(0, m * k) 1568 a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) 1569 1570 # a_batched is a regular CSR tensor but with a batch 1571 # dimension in the shape. It is unorthodox in PyTorch 1572 # to represent a batch sparse tensor in this way, 1573 # hence checking the tensor invariants is locally 1574 # turned off. 
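                # Rough sketch of why invariant checks must be off here: a genuinely batched
                # CSR tensor of shape (batch_size, m, k) would need crow_indices of shape
                # (batch_size, m + 1), but below the 1-D indices of the single (m, k) matrix
                # `a` are reused and only the shape is widened; the dense reference in
                # run_test then multiplies that same `a` against every batch of `b`.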
1575 a_batched = torch.sparse_csr_tensor( 1576 a.crow_indices(), a.col_indices(), a.values(), (batch_size, m, k), check_invariants=False) 1577 1578 b = make_tensor((batch_size, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous) 1579 for op_b, op_out in itertools.product([True, False], repeat=2): 1580 run_test(a, a_batched, b, op_b, op_out, dtype=dtype, device=device) 1581 1582 def run_test_block_addmm_addmv(self, 1583 addmv_addmm, 1584 c, 1585 a, 1586 b, 1587 op_b=False, 1588 op_out=False, 1589 *, 1590 dtype=None, 1591 device=None, 1592 ref=_npref_block_addmm_addmv): 1593 alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random() 1594 beta = complex(random.random(), random.random()) if dtype.is_complex else random.random() 1595 b = b.mH if (op_b and a.shape == b.shape) else b 1596 1597 actual = addmv_addmm(c, a, b, alpha=alpha, beta=beta) 1598 1599 out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c) 1600 addmv_addmm(c, a, b, alpha=alpha, beta=beta, out=out) 1601 expected = ref(c, a, b, alpha, beta) 1602 1603 self.assertEqual(actual, out) 1604 self.assertEqual(actual, expected, lambda msg: f"{msg}\na={a}\nc={c}\nb={b}\nalpha={alpha} beta={beta}") 1605 1606 # TODO: block_size 1 is broken 1607 @parametrize("block_size", [2, 3]) 1608 @parametrize("index_dtype", [torch.int32, torch.int64]) 1609 @parametrize("noncontiguous", [True, False]) 1610 @skipCPUIfNoMklSparse 1611 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1612 @skipIfTorchDynamo("raises 'sparse matrix length is ambiguous; use getnnz()'") 1613 @dtypes(*floating_and_complex_types()) 1614 @dtypesIfCUDA(*floating_and_complex_types_and( 1615 *[torch.half] if SM53OrLater else [], 1616 *[torch.bfloat16] if SM80OrLater else [])) 1617 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 1618 torch.float64: 1e-5, torch.complex128: 1e-5, 1619 torch.float16: 1e-3, torch.bfloat16: 1e-3}) 1620 def test_block_addmm(self, device, dtype, index_dtype, block_size, noncontiguous): 1621 1622 def make_transposed_addmm_op(f): 1623 1624 def tt(t): 1625 if isinstance(t, torch.Tensor): 1626 return t.transpose(-2, -1) 1627 else: 1628 # assume numpy/scipy spmatrix 1629 return t.transpose() 1630 1631 @functools.wraps(f) 1632 def wrapper(c, a, b, alpha=None, beta=None, out=None): 1633 if out is not None: 1634 # the ref takes no out kwarg 1635 assert isinstance(out, torch.Tensor) 1636 # transpose inplace to propagate out to checking context 1637 out.transpose_(-2, -1) 1638 return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta, out=out) 1639 else: 1640 return f(tt(c), tt(b), tt(a), alpha=alpha, beta=beta) 1641 1642 return wrapper 1643 1644 def ref_sp_numpy(c, a, b, alpha=None, beta=None, out=None): 1645 1646 def prep_input(t): 1647 1648 def to_sp_block_compressed(t): 1649 1650 if t.layout is torch.sparse_bsc: 1651 tt = t.transpose(-1, -2) 1652 else: 1653 tt = t 1654 1655 t_sp_bsr = sp.bsr_matrix( 1656 ( 1657 tt.values().cpu().numpy(), 1658 tt.col_indices().cpu().numpy(), 1659 tt.crow_indices().cpu().numpy(), 1660 ), 1661 shape=tt.shape, 1662 ) 1663 1664 if t.layout is torch.sparse_bsc: 1665 return t_sp_bsr.transpose() 1666 else: 1667 return t_sp_bsr 1668 1669 if t.layout is not torch.strided: 1670 return to_sp_block_compressed(t) 1671 else: 1672 return t.cpu().resolve_conj().numpy() 1673 1674 res = _npref_block_addmm_addmv( 1675 *(prep_input(t) for t in (c, a, b)), 1676 alpha, 1677 beta 1678 ) 1679 1680 if out is not None: 1681 out.copy_(res) 1682 return out 1683 else: 1684 return res 1685 
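        # The second reference defined next exists because the SciPy-based path above cannot
        # be used for half/bfloat16 inputs (NumPy has no bfloat16 dtype); it instead upcasts
        # both matrix operands to float32, multiplies densely, and casts the result back.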
1686 def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None): 1687 res = alpha * (a.to_dense().to(torch.float32) @ b.to_dense().to(torch.float32)).to(a.dtype) + beta * c 1688 if out is not None: 1689 out.copy_(res) 1690 return out 1691 else: 1692 return res 1693 1694 if dtype in (torch.half, torch.bfloat16): 1695 ref = ref_half_bfloat16 1696 else: 1697 ref = ref_sp_numpy 1698 1699 for (m, n, k) in itertools.product([2, 5], repeat=3): 1700 nnz = random.randint(0, m * k) 1701 a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) 1702 a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device) 1703 a_data = a_data.mT if noncontiguous else a_data 1704 a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(), 1705 a_data, (m * block_size, k * block_size), check_invariants=False) 1706 b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous) 1707 c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous) 1708 for op_b, op_out in itertools.product([True, False], repeat=2): 1709 self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device, ref=ref) 1710 self.run_test_block_addmm_addmv(make_transposed_addmm_op(torch.addmm), 1711 c, 1712 a, 1713 b, 1714 op_b, 1715 op_out, 1716 dtype=dtype, 1717 device=device, 1718 ref=make_transposed_addmm_op(ref)) 1719 1720 @parametrize("block_size", [2, 3]) 1721 @parametrize("index_dtype", [torch.int32, torch.int64]) 1722 @parametrize("noncontiguous", [True, False]) 1723 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1724 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 1725 def test_block_addmv(self, device, dtype, index_dtype, block_size, noncontiguous): 1726 # TODO: Explicitly disable block size 1 support 1727 # if (TEST_WITH_ROCM or not TEST_CUSPARSE_GENERIC) and block_size == 1: 1728 # return 1729 def ref_block_addmv(c, a, b, alpha, beta): 1730 return _npref_block_addmm_addmv(c, a.to_dense(), b, alpha, beta) 1731 1732 for (m, k) in itertools.product([2, 5], repeat=2): 1733 nnz = random.randint(0, m * k) 1734 if not noncontiguous: 1735 a = self.genSparseCSRTensor((m * block_size, k * block_size), nnz, 1736 dtype=dtype, device=device, index_dtype=index_dtype) 1737 a = a.to_sparse_bsr((block_size, block_size)) 1738 else: 1739 a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype) 1740 a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device) 1741 a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks 1742 a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(), 1743 a_data, (m * block_size, k * block_size), check_invariants=False) 1744 b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous) 1745 c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous) 1746 self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device, ref=ref_block_addmv) 1747 1748 @parametrize("matrix_shape", [(3, 3), (5, 7), (11, 9)], name_fn=lambda x: "shape_{}x{}".format(*x)) 1749 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 1750 @onlyCPU 1751 def test_addmv(self, device, dtype, matrix_shape): 1752 mat = torch.randn(matrix_shape, dtype=dtype, device=device) 1753 mat[mat.real < 0] = 0 1754 sparse_mat = mat.to_sparse_csr() 1755 mvec = torch.randn((mat.size(1),), 
dtype=dtype, device=device) 1756 avec = torch.randn((mat.size(0),), dtype=torch.float64, device=device) 1757 ref_output = torch.addmv(avec, mat, mvec) 1758 output = torch.addmv(avec, sparse_mat, mvec) 1759 self.assertEqual(ref_output, output) 1760 1761 @parametrize("block_size", [2, 3]) 1762 @parametrize("index_dtype", [torch.int32, torch.int64]) 1763 @parametrize("noncontiguous", [True, False]) 1764 @skipCPUIfNoMklSparse 1765 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 1766 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 1767 def test_block_triangular_solve(self, device, dtype, index_dtype, block_size, noncontiguous): 1768 def run_test(a, b, upper, transpose, unitriangular, op_out): 1769 if unitriangular and self.device_type == 'cpu': 1770 # TODO: When unitriangular=True results are not correct on CPU 1771 return 1772 1773 if not upper and self.device_type == 'cpu': 1774 # TODO: When upper=False some generated inputs might crash on CPU 1775 return 1776 1777 actual = torch.triangular_solve(b, a, upper=upper, unitriangular=unitriangular, transpose=transpose) 1778 actual_X = actual.solution 1779 actual_A_clone = actual.cloned_coefficient 1780 self.assertTrue(actual_A_clone.numel() == 0) 1781 if a._nnz() == 0: 1782 self.assertTrue(actual_X.isnan().all()) 1783 return 1784 1785 # TODO: replace with torch method when implemented to_dense() on block sparse tensor 1786 a_bsr = sp.bsr_matrix( 1787 ( 1788 a.values().cpu().numpy(), 1789 a.col_indices().cpu().numpy(), 1790 a.crow_indices().cpu().numpy(), 1791 ), 1792 shape=a.shape, 1793 ) 1794 expected_X, _ = torch.triangular_solve( 1795 b, 1796 torch.tensor(a_bsr.todense(), device=device), 1797 transpose=transpose, 1798 upper=upper, 1799 unitriangular=unitriangular) 1800 1801 if expected_X.isnan().any(): 1802 # TODO: zeros on the diagonal are not handled for CPU path 1803 # there's no way to query this info from MKL 1804 if self.device_type == 'cuda' and not TEST_WITH_ROCM: 1805 self.assertTrue(actual_X.isnan().any() or actual_X.isinf().any()) 1806 return 1807 1808 self.assertEqual(actual_X, expected_X) 1809 1810 out = torch.empty_like(b.mH if op_out and a.shape == b.shape else b) 1811 torch.triangular_solve( 1812 b, a, 1813 upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone) 1814 ) 1815 self.assertEqual(out, actual_X) 1816 self.assertEqual(out, expected_X) 1817 1818 for (m, k) in itertools.product([2, 3], [1, 3]): 1819 nnz = random.randint(0, m * m) 1820 if not noncontiguous: 1821 a = self.genSparseCSRTensor((m * block_size, m * block_size), nnz, 1822 dtype=dtype, device=device, index_dtype=index_dtype) 1823 a = a.to_sparse_bsr((block_size, block_size)) 1824 else: 1825 a = self.genSparseCSRTensor((m, m), nnz, dtype=dtype, device=device, index_dtype=index_dtype) 1826 a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device) 1827 a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks 1828 a = torch.sparse_bsr_tensor(a.crow_indices(), a.col_indices(), 1829 a_data, (m * block_size, m * block_size), check_invariants=False) 1830 b = make_tensor((m * block_size, k), dtype=dtype, device=device, noncontiguous=noncontiguous) 1831 1832 for (upper, unitriangular, transpose, op_out) in itertools.product([True, False], repeat=4): 1833 run_test(a, b, upper, unitriangular, transpose, op_out) 1834 1835 @skipCPUIfNoMklSparse 1836 @unittest.skipIf(TEST_WITH_ROCM, "Only CUDA 11+ is supported") 1837 @dtypes(torch.double) 1838 def test_mm(self, device, dtype): 
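        # The nested helpers below compare addmm/mm against dense references with a sparse
        # operand (CSR or CSC) on either side, then exercise mm for every {CSR, CSC} x
        # {CSR, CSC} pair as well as empty operands and an orthogonal pair whose product is
        # exactly zero.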
1839 def test_shape(di, dj, dk, nnz0=None, nnz1=None): 1840 for index_dtype in [torch.int32, torch.int64]: 1841 alpha = random.random() 1842 beta = random.random() 1843 1844 def _test_addmm(t, x, y): 1845 # TODO: addmm doesn't support strided result for sparse inputs. 1846 # res = beta * t + alpha * (x @ y) 1847 res = torch.addmm(t, x, y, beta=beta, alpha=alpha) 1848 expected = torch.addmm(t, x.to_dense(), y.to_dense(), beta=beta, alpha=alpha) 1849 self.assertEqual(res, expected) 1850 1851 res = torch.addmm(t, x, y) 1852 expected = torch.addmm(t, x.to_dense(), y.to_dense()) 1853 self.assertEqual(res, expected) 1854 1855 def _test_mm(x, y): 1856 res = torch.mm(x, y) 1857 expected = torch.mm(x.to_dense(), y.to_dense()) 1858 if x.layout is torch.strided or y.layout is torch.strided: 1859 self.assertEqual(res.layout, torch.strided) 1860 else: 1861 self.assertEqual(res.layout, torch.sparse_csr) 1862 self.assertEqual(res.to_dense(), expected) 1863 1864 def _test(t, x, y): 1865 _test_addmm(t, x, y) 1866 _test_mm(x, y) 1867 1868 if nnz0 is None: 1869 nnz0 = random.randint(di * dk // 2, di * dk) 1870 t = torch.randn(di, dj, dtype=dtype, device=device) 1871 x = self.genSparseCSRTensor((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype) 1872 y = torch.randn(dk, dj, dtype=dtype, device=device) 1873 _test(t, x, y) 1874 1875 t = torch.randn(di, dj, dtype=dtype, device=device) 1876 x = self.genSparseCSCTensor((di, dk), nnz0, device=device, dtype=dtype, index_dtype=index_dtype) 1877 y = torch.randn(dk, dj, dtype=dtype, device=device) 1878 _test(t, x, y) 1879 1880 if nnz1 is None: 1881 nnz1 = random.randint(dk * dj // 2, dk * dj) 1882 t = torch.randn(di, dj, dtype=dtype, device=device) 1883 x = torch.randn(di, dk, dtype=dtype, device=device) 1884 y = self.genSparseCSRTensor((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype) 1885 _test(t, x, y) 1886 1887 t = torch.randn(di, dj, dtype=dtype, device=device) 1888 x = torch.randn(di, dk, dtype=dtype, device=device) 1889 y = self.genSparseCSCTensor((dk, dj), nnz1, device=device, dtype=dtype, index_dtype=index_dtype) 1890 _test(t, x, y) 1891 1892 x_shape, y_shape = x.shape, y.shape 1893 1894 gen_csr_csc = [self.genSparseCSRTensor, self.genSparseCSCTensor] 1895 1896 # Test mm({CSR, CSC}, {CSR, CSC}) 1897 for gen_x, gen_y in itertools.product(gen_csr_csc, gen_csr_csc): 1898 x = gen_x(x_shape, nnz0, device=device, dtype=dtype, index_dtype=index_dtype) 1899 y = gen_y(y_shape, nnz1, device=device, dtype=dtype, index_dtype=index_dtype) 1900 _test_mm(x, y) 1901 1902 def test_empty_inputs(lhs_layout, rhs_layout): 1903 xd = torch.rand(10, 0, device=device, dtype=dtype) 1904 yd = xd.transpose(-2, -1) 1905 zd = torch.rand(0, 0, device=device, dtype=dtype) 1906 1907 xls, yls, zls = (t.to_sparse(layout=lhs_layout) for t in (xd, yd, zd)) 1908 xrs, yrs, zrs = (t.to_sparse(layout=rhs_layout) for t in (xd, yd, zd)) 1909 1910 for ls, rs, ld, rd in [(xls, yrs, xd, yd), (xls, zrs, xd, zd), (zls, yrs, zd, yd), (zls, zrs, zd, zd)]: 1911 res_sparse = ls @ rs 1912 res_dense = ld @ rd 1913 self.assertEqual(res_sparse.to_dense(), res_dense) 1914 1915 def test_orthogonal_inputs(lhs_layout, rhs_layout): 1916 ones = torch.ones(2, 2, device=device, dtype=dtype) 1917 zeros = torch.zeros(2, 2, device=device, dtype=dtype) 1918 x = torch.cat((ones, zeros), -1).to_sparse(layout=lhs_layout) 1919 y = torch.cat((zeros, ones), -2).to_sparse(layout=rhs_layout) 1920 res = x @ y 1921 res_expected = torch.zeros(*res.shape, device=device, dtype=dtype, layout=res.layout) 
1922 self.assertEqual(res, res_expected) 1923 1924 for lhs_layout, rhs_layout in itertools.product([torch.sparse_csr, torch.sparse_csc], repeat=2): 1925 test_empty_inputs(lhs_layout, rhs_layout) 1926 test_orthogonal_inputs(lhs_layout, rhs_layout) 1927 1928 for i in [2, 4]: 1929 for j in [2, 4, 7]: 1930 for k in [2, 3, 7]: 1931 test_shape(i, j, k) 1932 test_shape(4, 4, 4, 0, 0) 1933 1934 @skipCPUIfNoMklSparse 1935 @dtypes(*floating_and_complex_types()) 1936 @dtypesIfCUDA(*floating_and_complex_types_and( 1937 *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [], 1938 *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else [])) 1939 @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}) 1940 def test_sparse_mm(self, device, dtype): 1941 def test_shape(d1, d2, d3, nnz, transposed, index_dtype): 1942 if transposed: 1943 D = torch.randn(d3, d2, dtype=dtype, device=device).t_() 1944 else: 1945 D = torch.randn(d2, d3, dtype=dtype, device=device) 1946 S = self.genSparseCSRTensor((d1, d2), nnz, device=device, dtype=dtype, index_dtype=index_dtype) 1947 S_dense = S.to_dense() 1948 self.assertEqual(torch.sparse.mm(S, D), torch.mm(S_dense, D)) 1949 1950 for index_dtype in [torch.int32, torch.int64]: 1951 test_shape(7, 8, 9, 20, False, index_dtype) 1952 test_shape(7, 8, 9, 20, True, index_dtype) 1953 1954 @dtypes(*floating_and_complex_types()) 1955 @dtypesIfCUDA(*floating_and_complex_types_and( 1956 *[torch.half] if SM53OrLater and TEST_CUSPARSE_GENERIC else [], 1957 *[torch.bfloat16] if SM80OrLater and TEST_CUSPARSE_GENERIC else [])) 1958 @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2}) 1959 def test_sparse_addmm(self, device, dtype): 1960 def test_shape(m, n, p, nnz, broadcast, index_dtype, alpha_beta=None): 1961 if alpha_beta is None: 1962 alpha = random.random() 1963 beta = random.random() 1964 else: 1965 alpha, beta = alpha_beta 1966 if broadcast: 1967 D1 = make_tensor((), dtype=dtype, device=device) 1968 else: 1969 D1 = make_tensor([n, p], dtype=dtype, device=device) 1970 D2 = make_tensor([m, p], dtype=dtype, device=device) 1971 S = self.genSparseCSRTensor([n, m], nnz, dtype=dtype, device=device, index_dtype=index_dtype) 1972 S_dense = S.to_dense() 1973 Y = torch.sparse.addmm(D1, S, D2, beta=beta, alpha=alpha) 1974 Y_dense = torch.addmm(D1, S_dense, D2, beta=beta, alpha=alpha) 1975 self.assertEqual(Y, Y_dense) 1976 1977 for index_dtype in [torch.int32, torch.int64]: 1978 test_shape(7, 8, 9, 20, False, index_dtype, None) 1979 test_shape(7, 8, 9, 20, True, index_dtype, None) 1980 test_shape(7, 8, 9, 20, False, index_dtype, (1, 0)) 1981 test_shape(7, 8, 9, 20, True, index_dtype, (1, 0)) 1982 test_shape(7, 8, 9, 20, False, index_dtype, (1, 1)) 1983 test_shape(7, 8, 9, 20, True, index_dtype, (1, 1)) 1984 1985 @skipCPUIfNoMklSparse 1986 @dtypes(*floating_and_complex_types()) 1987 @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, 1988 torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) 1989 @dtypesIfCUDA(*floating_types_and(torch.complex64, 1990 *[torch.bfloat16] if SM80OrLater else [], 1991 *[torch.half] if SM53OrLater else [], 1992 *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else [])) 1993 @sparse_compressed_nonblock_layouts() 1994 @skipCUDAIf( 1995 not _check_cusparse_spgemm_available(), 1996 "cuSparse Generic API SpGEMM is not available" 1997 ) 1998 def test_addmm_all_sparse_csr(self, device, dtype, layout): 1999 M = torch.randn(10, 25, device=device).to(dtype) 2000 m1 = torch.randn(10, 50, 
device=device).to(dtype) 2001 m2 = torch.randn(50, 25, device=device).to(dtype) 2002 _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="all_sparse") 2003 2004 # Test 0-strided 2005 M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) 2006 m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50) 2007 m2 = torch.randn(50, 25, device=device).to(dtype) 2008 _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="all_sparse") 2009 2010 # Test beta=0, M=nan 2011 M = torch.full((10, 25), float('nan'), device=device).to(dtype) 2012 m1 = torch.randn(10, 50, device=device).to(dtype) 2013 m2 = torch.randn(50, 25, device=device).to(dtype) 2014 _test_addmm_addmv(self, torch.addmm, M, m1, m2, beta=0, layout=layout, mode="all_sparse") 2015 2016 # Test transpose 2017 for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): 2018 def maybe_transpose(cond, m): 2019 if not cond: 2020 return m 2021 return m.t().clone(memory_format=torch.contiguous_format).t() 2022 2023 M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) 2024 m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) 2025 m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) 2026 _test_addmm_addmv(self, torch.addmm, M, m1, m2, transpose_out=t4, layout=layout, mode="all_sparse") 2027 2028 @onlyCPU 2029 @skipCPUIfNoMklSparse 2030 @dtypes(*floating_and_complex_types()) 2031 @sparse_compressed_nonblock_layouts() 2032 def test_addmm_dense_result(self, device, dtype, layout): 2033 M = torch.randn(10, 25, device=device).to(dtype) 2034 m1 = torch.randn(10, 50, device=device).to(dtype) 2035 m2 = torch.randn(50, 25, device=device).to(dtype) 2036 _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="dense_result") 2037 2038 # Test 0-strided 2039 M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) 2040 m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50) 2041 m2 = torch.randn(50, 25, device=device).to(dtype) 2042 _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=layout, mode="dense_result") 2043 2044 # Test beta=0, M=nan 2045 M = torch.full((10, 25), float('nan'), device=device).to(dtype) 2046 m1 = torch.randn(10, 50, device=device).to(dtype) 2047 m2 = torch.randn(50, 25, device=device).to(dtype) 2048 _test_addmm_addmv(self, torch.addmm, M, m1, m2, beta=0, layout=layout, mode="dense_result") 2049 2050 # Test transpose 2051 for t1, t2, t3, t4 in itertools.product([True, False], repeat=4): 2052 def maybe_transpose(cond, m): 2053 if not cond: 2054 return m 2055 return m.t().clone(memory_format=torch.contiguous_format).t() 2056 2057 M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) 2058 m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) 2059 m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) 2060 _test_addmm_addmv(self, torch.addmm, M, m1, m2, transpose_out=t4, layout=layout, mode="dense_result") 2061 2062 @parametrize("k", [0, 1, 8]) 2063 @parametrize("n", [0, 1, 10]) 2064 @parametrize("m", [0, 1, 25]) 2065 @skipCPUIfNoMklSparse 2066 @dtypes(*floating_and_complex_types()) 2067 @dtypesIfCUDA(*floating_types_and(torch.complex64, 2068 *[torch.bfloat16] if SM80OrLater else [], 2069 *[torch.half] if SM53OrLater else [], 2070 *[torch.complex128] 2071 if CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED 2072 else [])) 2073 @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, 2074 torch.half: 1e-1, torch.cfloat: 
1e-4, torch.cdouble: 1e-8}) 2075 def test_addmm_sizes_all_sparse_csr(self, device, dtype, m, n, k): 2076 if (TEST_WITH_ROCM and k != 0 and n != 0 and m != 0): 2077 self.skipTest("Skipped on ROCm") 2078 M = torch.randn(n, m, device=device).to(dtype) 2079 m1 = torch.randn(n, k, device=device).to(dtype) 2080 m2 = torch.randn(k, m, device=device).to(dtype) 2081 _test_addmm_addmv(self, torch.addmm, M, m1, m2, layout=torch.sparse_csr, mode="all_sparse") 2082 2083 M = torch.randn(n, m, device=device).to(dtype).to_sparse_csr() 2084 m1 = torch.randn(n, k + 1, device=device).to(dtype).to_sparse_csr() 2085 m2 = torch.randn(k, m, device=device).to(dtype).to_sparse_csr() 2086 self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.addmm(M, m1, m2)) 2087 self.assertRaisesRegex(RuntimeError, f"{n}x{k + 1}.*{k}x{m}", lambda: torch.mm(m1, m2)) 2088 2089 @skipCPUIfNoMklSparse 2090 @dtypes(torch.float) 2091 def test_addmm_errors(self, device, dtype): 2092 # test that the errors are the same for dense and sparse versions 2093 import re 2094 2095 def test1(*, is_sparse): 2096 # shapes must be compatible for matrix multiplication 2097 a = make_tensor((2, 3), dtype=dtype, device=device) 2098 if is_sparse: 2099 a_sparse = a.to_sparse_csr() 2100 return torch.addmm(a, a_sparse, a) 2101 else: 2102 return torch.addmm(a, a, a) 2103 2104 def test2(*, is_sparse): 2105 # mat2 must be a matrix 2106 a = make_tensor((2, 3), dtype=dtype, device=device) 2107 if is_sparse: 2108 a_sparse = a.to_sparse_csr() 2109 return torch.addmm(a, a_sparse, a.unsqueeze(0)) 2110 else: 2111 return torch.addmm(a, a, a.unsqueeze(0)) 2112 2113 def test3(*, is_sparse): 2114 # the first input needs to be 1D or 2D 2115 a = make_tensor((3, 3), dtype=dtype, device=device) 2116 if is_sparse: 2117 a_sparse = a.to_sparse_csr() 2118 return torch.addmm(a.unsqueeze(0), a_sparse, a) 2119 else: 2120 return torch.addmm(a.unsqueeze(0), a, a) 2121 2122 for test in (test1, test2, test3): 2123 try: 2124 test(is_sparse=False) 2125 except RuntimeError as msg: 2126 with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))): 2127 test(is_sparse=True) 2128 2129 @skipCPUIfNoMklSparse 2130 @dtypes(torch.float) 2131 def test_mm_errors(self, device, dtype): 2132 # test that the errors are the same for dense and sparse versions 2133 import re 2134 2135 def test1(*, is_sparse): 2136 # shapes must be compatible for matrix multiplication 2137 a = make_tensor((2, 3), dtype=dtype, device=device) 2138 if is_sparse: 2139 a_sparse = a.to_sparse_csr() 2140 return torch.mm(a_sparse, a) 2141 else: 2142 return torch.mm(a, a) 2143 2144 def test2(*, is_sparse): 2145 # mat2 must be a matrix 2146 a = make_tensor((2, 3), dtype=dtype, device=device) 2147 if is_sparse: 2148 a_sparse = a.to_sparse_csr() 2149 return torch.mm(a_sparse, a.unsqueeze(0)) 2150 else: 2151 return torch.mm(a, a.unsqueeze(0)) 2152 2153 for test in (test1, test2): 2154 try: 2155 test(is_sparse=False) 2156 except RuntimeError as msg: 2157 with self.assertRaisesRegex(RuntimeError, re.escape(str(msg))): 2158 test(is_sparse=True) 2159 2160 @sparse_compressed_nonblock_layouts() 2161 @dtypes(torch.float, torch.double) 2162 def test_add(self, device, layout, dtype): 2163 def _test_spadd_shape(nnz, shape): 2164 # sparse.to_dense() uses torch.add internally so if torch.add is wrong, 2165 # the dense tensor will be wrong but this test would still pass 2166 # there's a separate test that checks for the correctness of the .to_dense() call 2167 x = self.genSparseCompressedTensor(shape, nnz, 2168 dtype=dtype, 
2169 device=device, 2170 index_dtype=torch.int32, 2171 layout=layout, 2172 blocksize=()) 2173 y = torch.randn(*shape, dtype=dtype, device=device) 2174 r = random.random() 2175 2176 res = torch.add(y, x, alpha=r) 2177 expected = y + r * x.to_dense() 2178 self.assertEqual(res, expected) 2179 res_perm = torch.add(x, y, alpha=r) 2180 self.assertEqual(res_perm, expected) 2181 2182 # Non contiguous dense tensor 2183 s = list(shape) 2184 s[0] = shape[-1] 2185 s[-1] = shape[0] 2186 y = torch.randn(*s, dtype=torch.double, device=device) 2187 y.transpose_(0, len(s) - 1) 2188 r = random.random() 2189 2190 res = torch.add(y, x, alpha=r) 2191 expected = y + r * x.to_dense() 2192 res_perm = torch.add(x, y, alpha=r) 2193 2194 self.assertEqual(res, expected) 2195 self.assertEqual(res_perm, expected) 2196 2197 2198 ns = [2, 5] 2199 batch_shapes = [(), (2,), (2, 3)] 2200 for b, m, n in itertools.product(batch_shapes, ns, ns): 2201 _test_spadd_shape(0, (*b, m, n)) 2202 _test_spadd_shape(m * n // 2, (*b, m, n)) 2203 _test_spadd_shape(m * n, (*b, m, n)) 2204 2205 @dtypes(torch.float, torch.double) 2206 def test_mul(self, device, dtype): 2207 # TODO: This whole test should be migrated to OpInfos 2208 def _test_spadd_shape(fn, nnz, shape): 2209 x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32) 2210 y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32) 2211 2212 # Forward comparison 2213 res_sparse_sparse = fn(y, x) 2214 res_dense_sparse = fn(y.to_dense(), x) 2215 res_sparse_dense = fn(y, x.to_dense()) 2216 expected = fn(y.to_dense(), x.to_dense()) 2217 self.assertEqual(res_sparse_sparse, expected) 2218 # TODO: While result of mul(dense, csr) is csr, it is not fully compressed. 2219 # That means it may contain materialized zeros, since the dense argument 2220 # is converted according to the sparsity pattern of csr. In the future 2221 # we might require the result to be fully compressed. 
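            # Concrete illustration (not asserted below): if the CSR operand stores an
            # explicit entry at (i, j) but the dense operand is zero there, the product keeps
            # an explicit 0.0 in values() at that position, so the result's nnz matches the
            # CSR pattern even though the entry could in principle be pruned.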
2222 self.assertEqual(res_dense_sparse, expected) 2223 self.assertEqual(res_sparse_dense, expected) 2224 2225 # Grad comparison 2226 x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32) 2227 y = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32) 2228 z = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32) 2229 2230 # csr * csr -> csr with csr, csr gradients 2231 x_a = x.clone().requires_grad_() 2232 y_a = y.clone().requires_grad_() 2233 2234 fn(y_a, x_a).backward(z) 2235 2236 x_dense_a = x.to_dense().requires_grad_() 2237 y_dense_a = y.to_dense().requires_grad_() 2238 2239 fn(y_dense_a, x_dense_a).backward(z.to_dense()) 2240 2241 self.assertEqual(x_a.grad.layout, torch.sparse_csr) 2242 self.assertEqual(y_a.grad.layout, torch.sparse_csr) 2243 2244 self.assertEqual(x_a.grad.to_dense(), x_dense_a.grad) 2245 self.assertEqual(y_a.grad.to_dense(), y_dense_a.grad) 2246 2247 # TODO: Currently strided Tensors cannot have csr gradients 2248 # dense * csr -> csr with csr, dense gradients 2249 x_a = x.clone().requires_grad_() 2250 y_a = y.to_dense().clone().requires_grad_() 2251 err_msg = "Function MulBackward0 returned an invalid gradient at index 0 - expected layout Strided but got SparseCsr" 2252 with self.assertRaisesRegex(RuntimeError, err_msg): 2253 fn(y_a, x_a).backward(z) 2254 2255 # csr * dense -> csr with dense, csr gradients 2256 x_a = x.to_dense().clone().requires_grad_() 2257 y_a = y.clone().requires_grad_() 2258 err_msg = "Function MulBackward0 returned an invalid gradient at index 1 - expected layout Strided but got SparseCsr" 2259 with self.assertRaisesRegex(RuntimeError, err_msg): 2260 fn(y_a, x_a).backward(z) 2261 2262 _test_spadd_shape(torch.mul, 100, [100, 100]) 2263 _test_spadd_shape(torch.mul, 0, [100, 100]) 2264 _test_spadd_shape(torch.mul, 100, [100, 1]) 2265 _test_spadd_shape(torch.mul, 100, [1, 100]) 2266 2267 # TODO: enable hybrid once to_dense supports it 2268 @parametrize('enable_hybrid', [False]) 2269 @all_sparse_compressed_layouts() 2270 @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half)) 2271 def test_mul_scalar(self, layout, device, dtype, enable_hybrid): 2272 for sparse in self.generate_simple_inputs( 2273 layout, device=device, dtype=dtype, index_dtype=torch.int32, enable_hybrid=enable_hybrid): 2274 for scalar_dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half): 2275 # ComplexHalf is experimental 2276 if dtype is torch.half and scalar_dtype.is_complex: 2277 continue 2278 2279 scalar_t = torch.tensor(2, dtype=scalar_dtype) 2280 for scalar in (scalar_t, scalar_t.item()): 2281 res_out = sparse.mul(scalar) 2282 self.assertEqual(res_out, scalar * sparse) 2283 2284 res_dense_out = sparse.to_dense().mul(scalar) 2285 # BUG: dispatcher ignores mul.Scalar(Tensor, Scalar) 2286 # This issues is circumvented in the mul(Tensor, Tensor) kernel. 
2287 self.assertEqual(res_out, res_dense_out) 2288 2289 if dtype == torch.result_type(sparse, scalar): 2290 res_in_dense = sparse.to_dense().mul_(scalar) 2291 res_in = sparse.clone().mul_(scalar) 2292 self.assertEqual(res_in, res_in_dense) 2293 self.assertEqual(res_out, res_in) 2294 2295 @skipCPUIfNoMklSparse 2296 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 2297 def test_sparse_add(self, device, dtype): 2298 def run_test(m, n, index_dtype): 2299 2300 alpha = random.random() 2301 nnz1 = random.randint(0, m * n) 2302 nnz2 = random.randint(0, m * n) 2303 nnz3 = random.randint(0, m * n) 2304 2305 if TEST_WITH_ROCM: 2306 # ROCm fails when nnz = 0 2307 nnz1, nnz2, nnz3 = max(1, nnz1), max(1, nnz2), max(1, nnz3) 2308 2309 S1 = self.genSparseCSRTensor([m, n], nnz1, dtype=dtype, device=device, index_dtype=index_dtype) 2310 S2 = self.genSparseCSRTensor([m, n], nnz2, dtype=dtype, device=device, index_dtype=index_dtype) 2311 S3 = self.genSparseCSRTensor([m, n], nnz3, dtype=dtype, device=device, index_dtype=index_dtype) 2312 sparse_args = [S1, S2, S3] 2313 dense_args = [t.to_dense() for t in sparse_args] 2314 arg_idx = list(range(len(sparse_args))) 2315 out_idx = arg_idx + [None] 2316 2317 for idx1, idx2, idx3 in itertools.product(arg_idx, arg_idx, out_idx): 2318 s1 = sparse_args[idx1] 2319 s2 = sparse_args[idx2] 2320 s3 = None if idx3 is None else sparse_args[idx3] 2321 d1 = dense_args[idx1] 2322 d2 = dense_args[idx2] 2323 d3 = None if idx3 is None else dense_args[idx3] 2324 2325 expected = torch.add(d1, d2, alpha=alpha, out=d3) 2326 actual = torch.add(s1, s2, alpha=alpha, out=s3) 2327 self.assertEqual(actual.crow_indices().dtype, index_dtype) 2328 self.assertEqual(actual.col_indices().dtype, index_dtype) 2329 self.assertEqual(actual, expected) 2330 self.assertEqual(s3, d3) 2331 if s3 is not None: 2332 self.assertEqual(s3.crow_indices().dtype, index_dtype) 2333 self.assertEqual(s3.col_indices().dtype, index_dtype) 2334 2335 for index_dtype in [torch.int32, torch.int64]: 2336 for m, n in itertools.product([3, 5], [3, 5]): 2337 run_test(m, n, index_dtype) 2338 2339 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 2340 def test_sparse_add_errors(self, device, dtype): 2341 def run_test(index_type): 2342 a = self.genSparseCSRTensor((2, 2), 3, dtype=dtype, device=device, index_dtype=index_dtype) 2343 b = self.genSparseCSRTensor((2, 1), 2, dtype=dtype, device=device, index_dtype=index_dtype) 2344 with self.assertRaisesRegex(RuntimeError, "Expected input tensors to have the same shape"): 2345 torch.add(a, b) 2346 2347 for index_dtype in [torch.int32, torch.int64]: 2348 run_test(index_dtype) 2349 2350 @skipCPUIfNoMklSparse 2351 @skipCUDAIf( 2352 not _check_cusparse_triangular_solve_available(), 2353 "cuSparse Generic API SpSV is not available" 2354 ) 2355 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 2356 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2357 torch.float64: 1e-8, torch.complex128: 1e-8}) 2358 def test_sparse_triangular_solve(self, device, dtype): 2359 2360 def run_test(n, k, upper, unitriangular, transpose, zero): 2361 if not unitriangular: 2362 triangle_function = torch.triu if upper else torch.tril 2363 else: 2364 # Make sure diagonal elements are not materialized. 2365 # This is to exercise `unitriangular=True` not relying on 2366 # explicit presence of these indices. 
2367 if upper: 2368 def remove_diagonal(t): 2369 return t.triu(-1) 2370 2371 else: 2372 def remove_diagonal(t): 2373 return t.tril(-1) 2374 2375 triangle_function = remove_diagonal 2376 2377 make_A = torch.zeros if zero else make_tensor 2378 A = make_A((n, n), dtype=dtype, device=device) 2379 A = triangle_function(A) 2380 A_sparse = A.to_sparse_csr() 2381 B = make_tensor((n, k), dtype=dtype, device=device) 2382 2383 expected = torch.triangular_solve(B, A, upper=upper, unitriangular=unitriangular, transpose=transpose) 2384 expected_X = expected.solution 2385 2386 actual = torch.triangular_solve(B, A_sparse, upper=upper, unitriangular=unitriangular, transpose=transpose) 2387 actual_X = actual.solution 2388 actual_A_clone = actual.cloned_coefficient 2389 self.assertTrue(actual_A_clone.numel() == 0) 2390 if A_sparse._nnz() == 0: 2391 self.assertTrue(actual_X.isnan().all()) 2392 return 2393 self.assertEqual(actual_X, expected_X) 2394 2395 # test out with C contiguous strides 2396 out = torch.empty_strided((n, k), (k, 1), dtype=dtype, device=device) 2397 torch.triangular_solve( 2398 B, A_sparse, 2399 upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone) 2400 ) 2401 self.assertEqual(out, expected_X) 2402 2403 # test out with F contiguous strides 2404 out = torch.empty_strided((n, k), (1, n), dtype=dtype, device=device) 2405 torch.triangular_solve( 2406 B, A_sparse, 2407 upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone) 2408 ) 2409 self.assertEqual(out, expected_X) 2410 self.assertEqual(out.stride(), (1, n)) 2411 2412 # test out with discontiguous strides 2413 out = torch.empty_strided((2 * n, k), (1, 2 * n), dtype=dtype, device=device)[::2] 2414 if n > 0 and k > 0: 2415 self.assertFalse(out.is_contiguous()) 2416 self.assertFalse(out.t().is_contiguous()) 2417 before_stride = out.stride() 2418 torch.triangular_solve( 2419 B, A_sparse, 2420 upper=upper, unitriangular=unitriangular, transpose=transpose, out=(out, actual_A_clone) 2421 ) 2422 self.assertEqual(out, expected_X) 2423 self.assertEqual(out.stride(), before_stride) 2424 2425 ks = [0, 1, 3] 2426 ns = [5, 3, 0] 2427 for (k, n), (upper, unitriangular, transpose, zero) in itertools.product(itertools.product(ks, ns), 2428 itertools.product([True, False], repeat=4)): 2429 run_test(n, k, upper, unitriangular, transpose, zero) 2430 2431 @skipCUDAIf( 2432 not _check_cusparse_sddmm_available(), 2433 "cuSparse Generic API SDDMM is not available" 2434 ) 2435 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 2436 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2437 torch.float64: 1e-8, torch.complex128: 1e-8}) 2438 def test_sampled_addmm(self, device, dtype): 2439 def run_test(c, a, b, op_a, op_b, *, alpha=None, beta=None): 2440 if dtype.is_complex: 2441 alpha = random.random() + 0.3j if alpha is None else alpha 2442 beta = random.random() + 0.6j if beta is None else beta 2443 else: 2444 alpha = random.random() if alpha is None else alpha 2445 beta = random.random() if beta is None else beta 2446 2447 if op_a and a.shape == b.shape: 2448 a = a.mH 2449 if op_b and a.shape == b.shape: 2450 b = b.mH 2451 2452 actual = torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta) 2453 2454 out = torch.sparse_csr_tensor( 2455 *map(torch.clone, (actual.crow_indices(), actual.col_indices())), 2456 torch.empty_like(actual.values()), 2457 size=actual.shape 2458 ) 2459 torch.sparse.sampled_addmm(c, a, b, alpha=alpha, beta=beta, out=out) 2460 2461 spy_c = 
torch.sparse_csr_tensor(c.crow_indices(), c.col_indices(), torch.ones_like(c.values()), size=c.shape) 2462 expected = alpha * (a @ b) * spy_c.to_dense() + beta * c.to_dense() 2463 self.assertEqual(actual.to_dense(), out.to_dense()) 2464 self.assertEqual(actual.to_dense(), expected) 2465 2466 mnk = list(itertools.product([2, 5], repeat=3)) 2467 2468 # Add a test case for size 0 a and b tensors 2469 mnk = mnk + [(5, 5, 0)] 2470 2471 batch_shapes = [(), (2,), (2, 3)] 2472 tf = [True, False] 2473 for index_dtype in [torch.int32, torch.int64]: 2474 for (m, n, k), b, noncontiguous, bcast_c in itertools.product(mnk, batch_shapes, tf, tf): 2475 if bcast_c and len(b) == 0: 2476 continue 2477 nnz = random.randint(0, m * n) 2478 c_batch = () if bcast_c else b 2479 c = self.genSparseCSRTensor((*c_batch, m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype) 2480 a = make_tensor((*b, m, k), dtype=dtype, device=device, noncontiguous=noncontiguous) 2481 b = make_tensor((*b, k, n), dtype=dtype, device=device, noncontiguous=noncontiguous) 2482 for op_a, op_b in itertools.product([True, False], repeat=2): 2483 run_test(c, a, b, op_a, op_b) 2484 2485 @skipCUDAIf( 2486 not _check_cusparse_sddmm_available(), 2487 "cuSparse Generic API SDDMM is not available" 2488 ) 2489 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 2490 def test_sampled_addmm_autograd(self, device, dtype): 2491 from torch.testing._internal.common_methods_invocations import sample_inputs_sparse_sampled_addmm 2492 2493 samples = list(sample_inputs_sparse_sampled_addmm(None, device, dtype, requires_grad=True)) 2494 2495 for sample, dense_covector in zip(samples, [True, False]): 2496 c = sample.input 2497 a = sample.args[0] 2498 b = sample.args[1] 2499 2500 # Compute sparse result 2501 output = torch.sparse.sampled_addmm(c, a, b, **sample.kwargs) 2502 covector = torch.randn_like(output).to_dense() if dense_covector else torch.randn_like(output) 2503 output.backward(covector) 2504 2505 # Compute dense result and compare with sparse result 2506 c1, a1, b1 = (x.detach().to_dense().requires_grad_(True) for x in [c, a, b]) 2507 dense_output = sample.kwargs['alpha'] * (a1 @ b1) * torch.ones_like(c).to_dense() + sample.kwargs['beta'] * c1 2508 self.assertEqual(output, dense_output) 2509 dense_covector = covector.to_dense() 2510 dense_output.backward(dense_covector) 2511 self.assertEqual(c.grad, c1.grad) 2512 self.assertEqual(a.grad, a1.grad) 2513 self.assertEqual(b.grad, b1.grad) 2514 2515 @onlyCUDA 2516 # It works on ROCm and CUDA issue is currently active 2517 @skipCUDAIf(not TEST_WITH_ROCM, "Causes CUDA memory exception, see https://github.com/pytorch/pytorch/issues/72177") 2518 @skipCUDAIf( 2519 not _check_cusparse_sddmm_available(), 2520 "cuSparse Generic API SDDMM is not available" 2521 ) 2522 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 2523 @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, 2524 torch.float64: 1e-8, torch.complex128: 1e-8}) 2525 def test_sampled_addmm_zero_sized(self, device, dtype): 2526 def run_test(c, a, b): 2527 actual = torch.sparse.sampled_addmm(c, a, b) 2528 self.assertEqual(actual.shape, c.shape) 2529 2530 for m, n, k in itertools.product([0, 5], repeat=3): 2531 c = torch.empty(m, n, dtype=dtype, device=device, layout=torch.sparse_csr) 2532 a = make_tensor((m, k), dtype=dtype, device=device) 2533 b = make_tensor((k, n), dtype=dtype, device=device) 2534 run_test(c, a, b) 2535 2536 @onlyCUDA 2537 @skipCUDAIf( 2538 not 
_check_cusparse_sddmm_available(), 2539 "cuSparse Generic API SDDMM is not available" 2540 ) 2541 @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) 2542 def test_sampled_addmm_errors(self, device, dtype): 2543 # test that the errors are the same for dense and sparse sampled versions 2544 # import re 2545 2546 # shapes must be compatible for matrix multiplication 2547 a = make_tensor((2, 3), dtype=dtype, device=device) 2548 a_sparse = a.to_sparse_csr() 2549 with self.assertRaisesRegex(RuntimeError, r"cannot be multiplied"): 2550 torch.sparse.sampled_addmm(a_sparse, a, a) 2551 2552 # mat1 must be a matrix 2553 with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to be a matrix"): 2554 torch.sparse.sampled_addmm(a_sparse, a[..., 0, :], a) 2555 2556 # mat2 must be a matrix 2557 with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to be a matrix"): 2558 torch.sparse.sampled_addmm(a_sparse, a, a[..., 0, :]) 2559 2560 a = make_tensor((2, 2), dtype=dtype, device=device) 2561 b = make_tensor((3, 3), dtype=dtype, device=device) 2562 b_sparse = b.to_sparse_csr() 2563 with self.assertRaisesRegex(RuntimeError, r"self.shape\[-2\] must match mat1.shape\[-2\]"): 2564 torch.sparse.sampled_addmm(b_sparse, a, a) 2565 2566 b = make_tensor((2, 3), dtype=dtype, device=device) 2567 b_sparse = b.to_sparse_csr() 2568 with self.assertRaisesRegex(RuntimeError, r"self.shape\[-1\] must match mat2.shape\[-1\]"): 2569 torch.sparse.sampled_addmm(b_sparse, a, a) 2570 2571 a = make_tensor((2, 2), dtype=dtype, device=device) 2572 a_sparse = a.to_sparse_csr() 2573 with self.assertRaisesRegex(RuntimeError, r"Expected mat1 to have strided layout"): 2574 torch.sparse.sampled_addmm(a_sparse, a_sparse, a_sparse) 2575 2576 with self.assertRaisesRegex(RuntimeError, r"Expected mat2 to have strided layout"): 2577 torch.sparse.sampled_addmm(a_sparse, a, a_sparse) 2578 2579 @onlyCPU 2580 @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16) 2581 @precisionOverride({torch.bfloat16: 0.01}) 2582 def test_sparse_mm_reduce_sum(self, device, dtype): 2583 def run_test(m, n, k, nnz, train): 2584 sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64) 2585 dense = sparse.to_dense() 2586 2587 mat = torch.randn(k, n, dtype=dtype) 2588 ref_mat = mat.clone() 2589 2590 if train: 2591 sparse.requires_grad_() 2592 mat.requires_grad_() 2593 dense.requires_grad_() 2594 ref_mat.requires_grad_() 2595 2596 ref_out = torch.mm(dense, ref_mat) 2597 out = torch.sparse.mm(sparse, mat, 'sum') 2598 2599 self.assertEqual(out, ref_out) 2600 2601 if train: 2602 ref_out.sum().backward() 2603 out.sum().backward() 2604 2605 grad_input = sparse.grad 2606 ref_grad_input = dense.grad 2607 grad_mat = mat.grad 2608 ref_grad_mat = ref_mat.grad 2609 2610 self.assertEqual(grad_input.to_dense(), ref_grad_input) 2611 self.assertEqual(grad_mat, ref_grad_mat) 2612 2613 run_test(4, 5, 4, 10, False) 2614 run_test(4, 4, 4, 16, True) 2615 2616 @skipIfTorchDynamo() 2617 @onlyCPU 2618 @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16) 2619 @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01}) 2620 def test_sparse_mm_reduce(self, device, dtype): 2621 def run_test(m, n, k, nnz, reduce_type, index_dtype, train): 2622 csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype) 2623 mat = torch.randn(n, k, dtype=dtype) 2624 ref_mat = mat.clone() 2625 ref_values = csr.values().clone() 2626 2627 out_int32 = index_dtype == torch.int32 2628 
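            # The private helper below expands the compressed row pointers back into one row
            # index per stored element (e.g. crow_indices [0, 2, 3] -> rows [0, 0, 1]) while
            # col_indices pass through unchanged; the resulting row/col pair feeds the
            # scatter_reduce-based reference implementation defined in ref().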
coo_indices = torch._convert_indices_from_csr_to_coo( 2629 csr.crow_indices(), 2630 csr.col_indices(), 2631 out_int32=out_int32) 2632 row, col = coo_indices[0], coo_indices[1] 2633 2634 def ref(row, col, val, mat): 2635 out = torch.zeros([m, k], dtype=dtype) 2636 weight = mat.index_select(0, col) 2637 src = weight.mul(val.view(-1, 1)) 2638 index = row.view(-1, 1).expand_as(weight) 2639 index = index.to(dtype=torch.int64) 2640 # scatter_reduce expect index to be int64 2641 out.scatter_reduce_(0, index, src, reduce=reduce_type, include_self=False) 2642 return out 2643 2644 if train: 2645 csr.requires_grad_() 2646 mat.requires_grad_() 2647 ref_values.requires_grad_() 2648 ref_mat.requires_grad_() 2649 2650 ref_out = ref(row, col, ref_values, ref_mat) 2651 out = torch.sparse.mm(csr, mat, reduce_type) 2652 self.assertEqual(out, ref_out) 2653 2654 if train and dtype not in (torch.bfloat16, torch.float16): 2655 ref_out.sum().backward() 2656 out.sum().backward() 2657 2658 grad_values = csr.grad.values() 2659 grad_weight = mat.grad 2660 ref_grad_values = ref_values.grad 2661 ref_grad_weight = ref_mat.grad 2662 self.assertEqual(grad_values, ref_grad_values) 2663 self.assertEqual(grad_weight, ref_grad_weight) 2664 2665 for train in [False, True]: 2666 for index_dtype in [torch.int32, torch.int64]: 2667 for reduce_type in ["sum", "mean", "amax", "amin"]: 2668 # by setting nnz < M, create empty rows 2669 run_test(3, 4, 11, 1, reduce_type, index_dtype, train) 2670 run_test(3, 4, 11, 6, reduce_type, index_dtype, train) 2671 run_test(3, 4, 11, 12, reduce_type, index_dtype, train) 2672 # we are doing blocking with 4x vector length in the kernel, 2673 # so need to test when K > 4x vector length 2674 run_test(4, 7, 33, 13, reduce_type, index_dtype, train) 2675 2676 @skipMeta 2677 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 2678 def test_coo_csr_conversion(self, device, dtype): 2679 for m, n in itertools.product([5, 2, 0], [5, 2, 0]): 2680 size = (m, n) 2681 dense = make_tensor(size, dtype=dtype, device=device) 2682 coo_sparse = dense.to_sparse() 2683 csr_sparse = coo_sparse.to_sparse_csr() 2684 2685 self.assertEqual(csr_sparse.to_dense(), dense) 2686 2687 @skipMeta 2688 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 2689 def test_csr_coo_conversion(self, device, dtype): 2690 for m, n in itertools.product([5, 2, 0], [5, 2, 0]): 2691 size = (m, n) 2692 dense = make_tensor(size, dtype=dtype, device=device) 2693 csr_sparse = dense.to_sparse_csr() 2694 coo_sparse = csr_sparse.to_sparse() 2695 2696 self.assertEqual(coo_sparse.to_dense(), dense) 2697 2698 # Currently, there is no rule in PyTorch for filling zeros in the outputs 2699 # from operations on Sparse CSR tensors. Hence only those operators are supported 2700 # which have 0->0 correspondence, example: sin(0) = 0, tan(0) = 0 but 2701 # cos(0) = 1 (and hence it's not supported). 
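    # Concrete sketch of the rule: torch.sin maps 0 -> 0, so it can be applied to values()
    # alone and every unstored entry stays an implicit zero; torch.cos would have to turn
    # every unstored entry into an explicit 1, i.e. densify the tensor, which is why such
    # operators are not in sparse_csr_unary_ufuncs.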
    # Note: here, we do this test only for unary operators
    @ops(sparse_csr_unary_ufuncs)
    def test_zero_to_zero_correspondence_unary(self, device, dtype, op):
        zero = torch.zeros((1, 2), dtype=dtype, device=device)
        tensor_explicit_zeros = torch.sparse_csr_tensor([0, 1], [1], [0], dtype=dtype, device=device)

        output_zero = op(zero)
        expected_zero = zero.to(output_zero.dtype)

        output_explicit_zeros = op(tensor_explicit_zeros).to_dense()
        expected_explicit_zeros = tensor_explicit_zeros.to_dense().to(output_explicit_zeros.dtype)

        for (output, expected) in [
            (output_zero, expected_zero),
            (output_explicit_zeros, expected_explicit_zeros)
        ]:
            self.assertEqual(output, expected, f"This operator ({op.name}) should not be supported for "
                             "Sparse CSR as it breaks 0->0 correspondence.")

        for inp in [zero.to_sparse_csr(), tensor_explicit_zeros]:
            self.assertEqual(op(inp).values().numel(), inp.values().numel(),
                             f"{op.name} fails to preserve sparsity pattern.")

    @ops(sparse_csr_unary_ufuncs)
    def test_sparse_csr_unary_out(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)

        if not op.supports_out:
            self.skipTest("Skipped! Out not supported")

        for sample in samples:
            assert torch.is_tensor(sample.input)
            # Sparse CSR only supports 2D tensors as inputs
            # Fail early to prevent silent success with this test
            if sample.input.ndim != 2:
                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")

            sample.input = sample.input.to_sparse_csr()
            expect = op(sample.input, *sample.args, **sample.kwargs)

            out = self.genSparseCSRTensor(sample.input.size(), sample.input._nnz(),
                                          device=sample.input.device, dtype=expect.dtype,
                                          index_dtype=sample.input.crow_indices().dtype)
            op(sample.input, *sample.args, **sample.kwargs, out=out)

            self.assertEqual(out, expect)

    @ops(sparse_csr_unary_ufuncs)
    def test_sparse_csr_unary_inplace(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)

        if op.inplace_variant is None:
            self.skipTest("Skipped! Inplace variant not supported!")

        for sample in samples:
            assert torch.is_tensor(sample.input)
            # Sparse CSR only supports 2D tensors as inputs
            # Fail early to prevent silent success with this test
            if sample.input.ndim != 2:
                raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")

            sample.input = sample.input.to_sparse_csr()
            expect = op(sample.input, *sample.args, **sample.kwargs)

            if not torch.can_cast(expect.dtype, dtype):
                with self.assertRaisesRegex(RuntimeError, "result type"):
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                continue

            if sample.input.is_complex() and op.name == "abs":
                with self.assertRaisesRegex(RuntimeError, "not supported"):
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                continue

            actual = op.inplace_variant(sample.input, *sample.args, **sample.kwargs)

            self.assertIs(actual, sample.input)
            self.assertEqual(actual, expect)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @ops(sparse_csr_unary_ufuncs, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble])
    def test_autograd_sparse_csr_unary(self, device, dtype, op):
        if op.name not in UNARY_EWISE_CSR_ALLOW_AUTOGRAD:
            self.skipTest(f"Skipped! Unary op {op.name} not supported with CSR input and autograd")

        samples = list(op.sample_inputs(device, dtype))

        # Fail early to prevent silent success with this test
        ndims_equals_2d = (s.input.ndim == 2 for s in samples)
        if not any(ndims_equals_2d):
            raise ValueError("Expected at least one 2D tensor in samples.")

        for sample in samples:
            # We must skip samples of low dimensionality; we cannot convert them to sparse compressed layouts
            if sample.input.ndim < 2:
                continue
            sparse_input = sample.input.to_sparse_csr().requires_grad_(True)

            def fn(input):
                output = op.gradcheck_wrapper(op.get_op(), input, *sample.args, **sample.kwargs)
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            # Compute sparse result
            output = fn(sparse_input)
            covector = torch.randn_like(output)
            output.backward(covector)
            self.assertTrue(torch.is_tensor(sparse_input.grad))
            self.assertTrue(sparse_input.grad.is_sparse_csr)

            # Compute dense result and compare with sparse result
            dense_input = sparse_input.detach().to_dense().requires_grad_(True)
            dense_output = fn(dense_input)
            dense_covector = covector.to_dense()
            dense_output.backward(dense_covector)
            self.assertEqual(sparse_input.grad, dense_input.grad)

    @skipCUDAIf(
        not _check_cusparse_sddmm_available(),
        "cuSparse Generic API SDDMM is not available"
    )
    @dtypes(torch.float64)
    def test_autograd_dense_output_addmm(self, device, dtype):
        from torch.testing._internal.common_methods_invocations import sample_inputs_addmm

        samples = list(sample_inputs_addmm(None, device, dtype, requires_grad=True))

        # Fail early to prevent silent success with this test
        ndims_equals_2d = (s.args[0].ndim == 2 for s in samples)
        if not any(ndims_equals_2d):
            raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.")

        for sample in samples:
            a = sample.args[0].relu().to_sparse_csr()
            if sample.args[0].shape == sample.args[1].shape:
                import warnings
2839 warnings.warn("Broken for square matrices, see https://github.com/pytorch/pytorch/issues/116565") 2840 continue 2841 2842 # This path tests the autograd path wrt dense inputs 2843 for addmm in [torch.addmm, torch.sparse.addmm]: 2844 2845 def fn(c, b): 2846 output = addmm(c, a, b, **sample.kwargs) 2847 if sample.output_process_fn_grad is not None: 2848 return sample.output_process_fn_grad(output) 2849 return output 2850 2851 self.assertTrue(torch.autograd.gradcheck(fn, [sample.input, sample.args[1]], fast_mode=True)) 2852 2853 # noncontiguous 2854 c = make_tensor(sample.input.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True) 2855 b = make_tensor(sample.args[1].shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True) 2856 self.assertTrue(torch.autograd.gradcheck(fn, [c, b], fast_mode=True)) 2857 2858 # Now test the autograd path wrt sparse inputs 2859 for reverse in [True, False]: 2860 c, b = sample.input, sample.args[1] 2861 if reverse and a.shape != b.shape: 2862 continue 2863 2864 def fn(a): 2865 inputs = (c, b, a) if reverse else (c, a, b) 2866 output = addmm(*inputs, **sample.kwargs) 2867 if sample.output_process_fn_grad is not None: 2868 return sample.output_process_fn_grad(output) 2869 return output 2870 2871 # gradcheck doesn't work for sparse CSR yet, compare against dense path 2872 # Compute sparse result 2873 a = a.detach().requires_grad_(True) 2874 output = fn(a) 2875 covector = torch.randn_like(output) 2876 output.backward(covector) 2877 self.assertTrue(torch.is_tensor(a.grad)) 2878 if addmm == torch.sparse.addmm: 2879 self.assertTrue(a.grad.is_sparse_csr) 2880 else: 2881 self.assertTrue(a.grad.layout == torch.strided) 2882 2883 # Compute dense result and compare with sparse result 2884 dense_a = a.detach().to_dense().requires_grad_(True) 2885 dense_output = fn(dense_a) 2886 self.assertEqual(output, dense_output) 2887 dense_covector = covector.to_dense() 2888 dense_output.backward(dense_covector) 2889 2890 if addmm == torch.sparse.addmm: 2891 self.assertEqual(a.grad, dense_a.grad.sparse_mask(a)) 2892 else: 2893 self.assertEqual(a.grad, dense_a.grad) 2894 2895 @skipCPUIfNoMklSparse 2896 @dtypes(torch.float64) 2897 def test_autograd_dense_output_addmv(self, device, dtype): 2898 from torch.testing._internal.common_methods_invocations import sample_inputs_addmv 2899 2900 samples = list(sample_inputs_addmv(None, device, dtype, requires_grad=True)) 2901 2902 # Fail early to prevent silent success with this test 2903 ndims_equals_2d = (s.args[0].ndim == 2 for s in samples) 2904 if not any(ndims_equals_2d): 2905 raise ValueError("Expected at least one 2D tensor in samples to convert to sparse.") 2906 2907 for sample in samples: 2908 # TODO: Remove detach once we have autograd support for CSR input 2909 a = sample.args[0].to_sparse_csr().detach() 2910 2911 def fn(c, b): 2912 output = torch.addmv(c, a, b, **sample.kwargs) 2913 if sample.output_process_fn_grad is not None: 2914 return sample.output_process_fn_grad(output) 2915 return output 2916 2917 self.assertTrue(torch.autograd.gradcheck(fn, [sample.input, sample.args[1]], fast_mode=True)) 2918 2919 # noncontiguous 2920 c = make_tensor(sample.input.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True) 2921 b = make_tensor(sample.args[1].shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True) 2922 self.assertTrue(torch.autograd.gradcheck(fn, [c, b], fast_mode=True)) 2923 2924 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 2925 
@ops(binary_ops_with_dense_output, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, ]) 2926 def test_autograd_dense_output(self, device, dtype, op): 2927 if op.name == "mv" and no_mkl_sparse and self.device_type == 'cpu': 2928 self.skipTest("MKL Sparse is not available") 2929 2930 samples = list(op.sample_inputs(device, dtype, requires_grad=True)) 2931 2932 # Fail early to prevent silent success with this test 2933 ndims_equals_2d = (s.input.ndim == 2 for s in samples) 2934 if not any(ndims_equals_2d): 2935 raise ValueError("Expected at least one 2D tensor in samples.") 2936 2937 # Here we assume that the signature is op(sparse_input, dense_input) -> dense_output 2938 for sample in samples: 2939 # TODO: Remove detach once we have autograd support for CSR input 2940 sparse_input = sample.input.to_sparse_csr().detach() 2941 2942 def fn(*args): 2943 output = op.gradcheck_wrapper(op.get_op(), sparse_input, *args, **sample.kwargs) 2944 if sample.output_process_fn_grad is not None: 2945 return sample.output_process_fn_grad(output) 2946 return output 2947 2948 self.assertTrue(torch.autograd.gradcheck(fn, sample.args, fast_mode=True)) 2949 2950 # noncontiguous 2951 args = [make_tensor(a.shape, device=device, dtype=dtype, noncontiguous=True, requires_grad=True) for a in sample.args] 2952 self.assertTrue(torch.autograd.gradcheck(fn, args, fast_mode=True)) 2953 2954 @dtypes(*all_types_and_complex()) 2955 def test_direct_coo_csr_conversion(self, device, dtype): 2956 for m, n in itertools.product([5, 2, 0], [5, 2, 0]): 2957 size = (m, n) 2958 dense = make_tensor(size, dtype=dtype, device=device) 2959 coo_sparse = dense.to_sparse_coo() 2960 2961 self.assertEqual(coo_sparse.to_sparse_csr().to_sparse_coo(), coo_sparse) 2962 2963 @skipMeta 2964 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 2965 def test_sum(self, device, dtype): 2966 def run_test(shape, nnz, index_type): 2967 a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype) 2968 self.assertEqual(a.sum(), a.values().sum()) 2969 if dtype in floating_types(): 2970 a.requires_grad_(True) 2971 a.sum().backward() 2972 self.assertEqual(a.grad, torch.ones(shape, dtype=dtype, device=device)) 2973 for shape, index_dtype in itertools.product( 2974 [(10, 5), (10, 10)], 2975 [torch.int32, torch.int64]): 2976 run_test(shape, 0, index_dtype) 2977 run_test(shape, max(shape), index_dtype) 2978 run_test(shape, shape[0] * shape[1], index_dtype) 2979 2980 @skipIfTorchDynamo() 2981 @skipMeta 2982 @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) 2983 @all_sparse_compressed_layouts() 2984 def test_transpose(self, device, dtype, layout): 2985 2986 def _check_transpose_view(subject, transpose): 2987 self.assertTrue(transpose.values()._is_view()) 2988 self.assertTrue(transpose._is_view()) 2989 self.assertTrue(transpose._base is subject) 2990 2991 def _check_layout_invariants(transpose): 2992 self.assertEqual(transpose.device, torch.device(device)) 2993 compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[transpose.layout] 2994 compressed_indices, plain_indices = compressed_indices_mth(transpose), plain_indices_mth(transpose) 2995 torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, transpose.values(), 2996 transpose.shape, transpose.layout) 2997 2998 def check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout): 2999 transpose = subject.transpose(dim0, dim1) 3000 # correct layout 3001 
            self.assertEqual(transpose.layout, expected_layout)
            # transpose must return a view
            _check_transpose_view(subject, transpose)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(transpose)
            self.assertEqual(transpose.to_dense(), subject_dense.transpose(dim0, dim1))

            round_trip = transpose.transpose(dim0, dim1)
            self.assertEqual(round_trip.layout, subject.layout)
            # transpose must return a view
            _check_transpose_view(subject, round_trip)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(round_trip)
            self.assertEqual(round_trip.to_dense(), subject_dense)

        def check_same_dim_transpose(subject, subject_dense, dim):
            transpose = subject.transpose(dim, dim)
            # correct layout
            self.assertEqual(transpose.layout, subject.layout)
            # transpose must return a view
            _check_transpose_view(subject, transpose)
            # result uses unsafe construction, so we check invariants
            _check_layout_invariants(transpose)
            self.assertEqual(transpose.to_dense(), subject_dense)

        def check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1):
            mismatch_name = f"{dim0}\\({name0}\\) and {dim1}\\({name1}\\)"
            err = r"transpose\(\): can only transpose dimensions of the same type \(Batch, Sparse, Dense\), got " + mismatch_name

            with self.assertRaisesRegex(RuntimeError, err):
                subject.transpose(dim0, dim1)

        def run_test(shape, nnz, index_type, n_dense, blocksize=()):
            subject = self.genSparseCompressedTensor(shape,
                                                     nnz,
                                                     layout=layout,
                                                     device=device,
                                                     index_dtype=index_type,
                                                     blocksize=blocksize,
                                                     dense_dims=n_dense,
                                                     dtype=dtype)

            sparse0 = len(shape) - n_dense - 1
            sparse1 = sparse0 - 1

            dense0 = sparse0 + 1 if n_dense > 0 else None
            dense1 = dense0 + 1 if n_dense > 1 else None

            n_batch = len(shape) - n_dense - 2
            batch0 = sparse1 - 1 if n_batch > 0 else None
            batch1 = 0 if n_batch > 1 else None

            sparse_dims = (sparse0, sparse1)
            dense_dims = (dense0, dense1)
            batch_dims = (batch0, batch1)

            named0 = [(name, d[0]) for name, d in zip(["Batch", "Sparse", "Dense"], (batch_dims, sparse_dims, dense_dims))]
            named1 = [(name, d[1]) for name, d in zip(["Batch", "Sparse", "Dense"], (batch_dims, sparse_dims, dense_dims))]

            flipped_layout = {
                torch.sparse_csr: torch.sparse_csc,
                torch.sparse_csc: torch.sparse_csr,
                torch.sparse_bsr: torch.sparse_bsc,
                torch.sparse_bsc: torch.sparse_bsr
            }[layout]
            if n_dense > 0:
                # expect all transposes to throw
                for (name0, dim0), (name1, dim1) in itertools.product(named0, named1):
                    msg = r"transpose\(\): hybrid sparse compressed tensors with dense dimensions are not supported"
                    if (dim0 is not None) and (dim1 is not None):
                        with self.assertRaisesRegex(RuntimeError, msg):
                            subject.transpose(dim0, dim1)
            else:
                subject_dense = subject.to_dense()
                for (name0, dim0), (name1, dim1) in itertools.product(named0, named1):
                    if dim0 is not None:
                        check_same_dim_transpose(subject, subject_dense, dim0)

                        if dim1 is not None:
                            if name0 == name1:
                                expected_layout = flipped_layout if name0 == "Sparse" else layout
                                check_good_transpose(subject, subject_dense, dim0, dim1, expected_layout)
                            else:
                                check_dim_type_mismatch_throws(subject, name0, dim0, name1, dim1)

        # batch/sparse, sparse/dense only and full hybrid cases
        shape_ndense = list(itertools.product([(2, 4, 6, 2), (10, 6, 4, 2), (2, 4, 4, 2, 6)], [0, 1, 2]))
        # sparse only cases
        shape_ndense += [[(4, 8), 0], [(2, 2), 0], [(8, 4), 0]]
        for (shape, n_dense), index_dtype in itertools.product(shape_ndense, [torch.int32, torch.int64]):
            n_batch = len(shape) - n_dense - 2
            sparse_shape = shape[n_batch: n_batch + 2]
            if layout in (torch.sparse_bsr, torch.sparse_bsc):
                # for blocked layouts, all combinations of 2 and 1 should be valid blocksizes
                run_test(shape, 0, index_dtype, n_dense, blocksize=(2, 2))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 2))
                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense, blocksize=(2, 2))
                # repeat the realistic sparsity case with varied block sizes
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(2, 1))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 2))
                run_test(shape, max(sparse_shape), index_dtype, n_dense, blocksize=(1, 1))
            else:
                run_test(shape, 0, index_dtype, n_dense)
                run_test(shape, max(sparse_shape), index_dtype, n_dense)
                run_test(shape, sparse_shape[0] * sparse_shape[1], index_dtype, n_dense)

    # TODO: This is a stopgap for a rigorous extension of our autograd tests
    # to test the functionality of detach
    @skipMeta
    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
    def test_exercise_detach(self, device, dtype):
        shape = (3, 3)
        nnz = 4
        for index_dtype in [torch.int32, torch.int64]:
            inp = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
            detached_inp = inp.detach()
            self.assertEqual(inp, detached_inp)

    def _construct_sp_matrix(self, tensor, layout, blocksize=(2, 2)):
        if tensor.layout in [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc, torch.strided]:
            tensor = tensor.to_dense()
        else:
            raise NotImplementedError(repr(tensor))
        if layout is torch.sparse_csr:
            return sp.csr_matrix(tensor.cpu().numpy())
        if layout is torch.sparse_csc:
            return sp.csc_matrix(tensor.cpu().numpy())
        if layout is torch.sparse_bsr:
            return sp.bsr_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices()
        if layout is torch.sparse_bsc:
            # SciPy doesn't have native BSC support - but our tests don't need the full
            # functionality so fake it by using a transposed BSR matrix.
3134 class FakeBscMatrix: 3135 def __init__(self, matrix): 3136 self._matrix = matrix 3137 self.shape = tuple(reversed(matrix.shape)) 3138 self.indptr = matrix.indptr 3139 self.indices = matrix.indices 3140 self.data = [x.transpose() for x in matrix.data] 3141 3142 @staticmethod 3143 def from_matrix(matrix, blocksize): 3144 blocksize = tuple(reversed(blocksize)) 3145 matrix = matrix.transpose() 3146 return FakeBscMatrix(sp.bsr_matrix(matrix, blocksize=blocksize)) 3147 3148 def sorted_indices(self): 3149 sub = self._matrix.sorted_indices() 3150 return FakeBscMatrix(sub) 3151 3152 return FakeBscMatrix.from_matrix(tensor.cpu().numpy(), blocksize=blocksize).sorted_indices() 3153 raise NotImplementedError(repr(tensor)) 3154 3155 @skipMeta 3156 @all_sparse_compressed_layouts('to_layout') 3157 @all_sparse_compressed_layouts('from_layout') 3158 def test_compressed_layout_conversions_coverage(self, device, from_layout, to_layout): 3159 """This test performs a smoke test for covered conversion and verifies 3160 that an exception is thrown for unsupported conversions. 3161 3162 TODO: This test covers a subset of 3163 TestSparseAny.test_to_sparse tests and can be 3164 eliminated. Keeping the test until the new 3165 `Tensor.to_sparse(*, layout, blocksize)` has landed. 3166 """ 3167 3168 allowed_pairwise_layouts_sets = { 3169 frozenset({torch.sparse_csc}), 3170 frozenset({torch.sparse_csr}), 3171 frozenset({torch.sparse_csc, torch.sparse_csr}), 3172 frozenset({torch.sparse_csc, torch.sparse_bsc}), 3173 frozenset({torch.sparse_csc, torch.sparse_bsr}), 3174 frozenset({torch.sparse_csr, torch.sparse_bsc}), 3175 frozenset({torch.sparse_csr, torch.sparse_bsr}), 3176 frozenset({torch.sparse_bsc}), 3177 frozenset({torch.sparse_bsr}), 3178 frozenset({torch.sparse_bsc, torch.sparse_bsr}), 3179 } 3180 block_layouts = (torch.sparse_bsr, torch.sparse_bsc) 3181 3182 def _to_from_layout(layout_a, layout_b, a): 3183 expect_error = True 3184 if {layout_a, layout_b} in allowed_pairwise_layouts_sets: 3185 expect_error = False 3186 3187 # BSR -> CSR is not yet supported 3188 if (layout_a, layout_b) == (torch.sparse_bsr, torch.sparse_csr): 3189 expect_error = True 3190 # BSR -> CSC is not yet supported 3191 if (layout_a, layout_b) == (torch.sparse_bsr, torch.sparse_csc): 3192 expect_error = True 3193 # BSC -> CSR is not yet supported 3194 if (layout_a, layout_b) == (torch.sparse_bsc, torch.sparse_csr): 3195 expect_error = True 3196 # BSC -> CSC is not yet supported 3197 if (layout_a, layout_b) == (torch.sparse_bsc, torch.sparse_csc): 3198 expect_error = True 3199 # CSR -> BSR only works for non-batched inputs 3200 if (layout_a, layout_b) == (torch.sparse_csr, torch.sparse_bsr): 3201 if a.dim() > 2: 3202 expect_error = True 3203 # CSR -> BSC only works for non-batched inputs 3204 if (layout_a, layout_b) == (torch.sparse_csr, torch.sparse_bsc): 3205 if a.dim() > 2: 3206 expect_error = True 3207 # CSC -> BSR only works for non-batched inputs 3208 if (layout_a, layout_b) == (torch.sparse_csc, torch.sparse_bsr): 3209 if a.dim() > 2: 3210 expect_error = True 3211 # CSC -> BSC only works for non-batched inputs 3212 if (layout_a, layout_b) == (torch.sparse_csc, torch.sparse_bsc): 3213 if a.dim() > 2: 3214 expect_error = True 3215 3216 blocksize_a = (1, 1) if layout_a in {torch.sparse_bsr, torch.sparse_bsc} else None 3217 blocksize_b = (1, 1) if layout_b in {torch.sparse_bsr, torch.sparse_bsc} else None 3218 b = a.to_sparse(layout=layout_a, blocksize=blocksize_a) 3219 if expect_error: 3220 with self.assertRaises(RuntimeError): 3221 
                    b.to_sparse(layout=layout_b, blocksize=blocksize_b)
            else:
                c = b.to_sparse(layout=layout_b, blocksize=blocksize_b)
                self.assertEqual(a.to_dense(), c.to_dense())

            # change of blocksize upon conversion is not yet supported.
            if b.layout in block_layouts:
                for block_layout in block_layouts:
                    with self.assertRaisesRegex(RuntimeError,
                                                "conversion from.*to.*with blocksize changed from.*to.*is not supported"):
                        b.to_sparse(layout=block_layout, blocksize=(3, 3))

        batch_dims = [(), (2,), (2, 2), (2, 2, 2)]
        sparse_dims = (6, 12)
        for batch_dim in batch_dims:
            a = make_tensor(batch_dim + sparse_dims, dtype=torch.float, device=device)
            _to_from_layout(from_layout, to_layout, a)

    @skipMeta
    @all_sparse_compressed_layouts()
    @batched_nonbatched()
    @hybrid_nonhybrid()
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_dense_to_from_sparse_compressed(self, device, hybrid, batched, layout):
        """This test checks conversion between dense tensors and the sparse
        compressed layouts by comparing against SciPy's implementation.

        Here we test only those conversion combinations that SciPy
        supports, to ensure that PyTorch's conversions agree with SciPy.
        Independently of SciPy, all conversion combinations are tested in
        TestSparseAny.test_to_sparse.
        """

        blocked_layouts = (torch.sparse_bsr, torch.sparse_bsc)

        # helpers

        def _check_against_scipy_matrix(pt_matrix, dense, blocksize, **kwargs):
            # scipy has no bsc layout, so we check against the bsr layout of the transposed dense
            if layout == torch.sparse_bsc:
                sp_matrix = self._construct_sp_matrix(dense.t(), layout=torch.sparse_bsr, blocksize=blocksize[::-1])
            else:
                sp_matrix = self._construct_sp_matrix(dense, layout=layout, blocksize=blocksize)

            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]

            self.assertEqual(layout, pt_matrix.layout)
            if layout == torch.sparse_bsc:
                self.assertEqual(sp_matrix.shape[::-1], pt_matrix.shape)
            else:
                self.assertEqual(sp_matrix.shape, pt_matrix.shape)

            self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix))
            self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix))
            if layout == torch.sparse_bsc:
                # we must transpose the blocks before comparing
                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values().transpose(-2, -1))
            else:
                self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values())

        def _check_hybrid_matrix(pt_matrix, dense, blocksize, **kwargs):
            # Calculate COO indices for the sparse matrix.
            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
            compressed_indices = compressed_indices_mth(pt_matrix)
            plain_indices = plain_indices_mth(pt_matrix)
            coo_indices = torch._convert_indices_from_csr_to_coo(compressed_indices, plain_indices)
            row_indices, col_indices = {
                torch.sparse_csr: (coo_indices[0, ], coo_indices[1, ]),
                torch.sparse_csc: (coo_indices[1, ], coo_indices[0, ]),
                torch.sparse_bsr: (coo_indices[0, ], coo_indices[1, ]),
                torch.sparse_bsc: (coo_indices[1, ], coo_indices[0, ]),
            }[pt_matrix.layout]

            # If the sparse matrix layout is blocked, rearrange the dense
            # matrix so that its shape past the first two dimensions matches
            # the shape of the sparse matrix values: a (M, N, *hybrid) dense
            # tensor is reshaped to (M // BM, BM, N // BN, BN, *hybrid) and the
            # BM and N // BN dimensions are swapped so that indexing with block
            # row/column indices yields blocks of shape (BM, BN, *hybrid).
            dense_to_check = dense
            if blocksize:
                dense_shape = dense.shape
                dense_to_check_shape = (dense.shape[0] // blocksize[0],
                                        blocksize[0],
                                        dense.shape[1] // blocksize[1],
                                        blocksize[1]) + dense.shape[2:]
                dense_to_check = dense_to_check.reshape(dense_to_check_shape).transpose(1, 2)

            # Verify that non-zero values of the sparse matrix are
            # equal to the corresponding values of the dense matrix.
            self.assertEqual(pt_matrix.values(), dense_to_check[row_indices, col_indices])

            # Verify that the remaining elements of the dense matrix
            # are 0, i.e. that the dense and sparse matrices are fully
            # equal.
            mask = torch.ones_like(dense_to_check, dtype=torch.bool)
            mask[row_indices, col_indices] = False
            self.assertTrue(torch.all(torch.masked_select(dense_to_check, mask) == 0))

        def _check_batched(pt_tensor, dense, check_batch=None, batch_shape=(), blocksize=(), **kwargs):
            self.assertEqual(layout, pt_tensor.layout)
            self.assertEqual(pt_tensor.shape, dense.shape)
            compressed_indices_mth, plain_indices_mth = sparse_compressed_indices_methods[layout]
            for batch_index in np.ndindex(batch_shape):
                pt_matrix = pt_tensor[batch_index]
                dense_matrix = dense[batch_index]
                dense_dim = pt_matrix.dim() - 2
                dense_matrix_pt = dense_matrix.to_sparse(layout=layout,
                                                         blocksize=blocksize or None,
                                                         dense_dim=dense_dim)
                # sanity check: selecting a batch of to_<layout> and dense[batch].to_<layout> should give the same result
                self.assertEqual(pt_matrix, dense_matrix_pt)
                check_batch(pt_matrix, dense_matrix, blocksize, **kwargs)

        def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
            shape = batch_shape + sparse_shape + hybrid_shape
            n_batch_dim = len(batch_shape)
            n_hybrid_dim = len(hybrid_shape)
            # generate a dense tensor
            dense = make_tensor(shape, dtype=torch.float, device=device)

            # introduce some sparsity; the mask has the sparse shape, each element applies to an
            # entire dense sub-tensor (hybrid) and is applied to each batch
            mask = make_tensor(sparse_shape, dtype=torch.bool, device=device)
            # manually expand to match hybrid shape
            if hybrid:
                mask = mask.view(sparse_shape + tuple(1 for _ in range(n_hybrid_dim)))
                mask = mask.expand(sparse_shape + hybrid_shape)

            # mask will broadcast over the batch dims if present

            return dense * mask

        # note: order is important here, the hybrid-ness decides the inner content check which is used to build the
        # batched checker (if needed)
        check_content = _check_against_scipy_matrix
        if hybrid:
            check_content = _check_hybrid_matrix
        if batched:
            check_content = functools.partial(_check_batched, check_batch=check_content)

        sparse_sizes = [(6, 10), (0, 10), (6, 0), (0, 0)]
        blocksizes = [(2, 2), (1, 1), (1, 2)] if layout in blocked_layouts else [()]
        batch_sizes = [(3,), (1, 3), (2, 1, 3)] if batched else [()]
        hybrid_sizes = [(4, ), (2, 2)] if hybrid else [()]

        # general cases, always run
        for sparse_shape, blocksize, batch_shape, hybrid_shape in itertools.product(
                sparse_sizes, blocksizes, batch_sizes, hybrid_sizes):
            dense = _generate_subject(sparse_shape, batch_shape, hybrid_shape)
            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)
            dense_back = sparse.to_dense()
            self.assertEqual(dense, dense_back)

        # special cases for batched tensors
        if batched:
            # batched sparse tensors need only have the same number of non-zeros in each batch,
            # not necessarily the same sparsity pattern in each batch
            sparse_shape = sparse_sizes[0]
            hybrid_shape = hybrid_sizes[0]
            batch_shape = batch_sizes[0]
            shape = batch_shape + sparse_shape + hybrid_shape
            dense = make_tensor(shape, dtype=torch.float, device=device)
            blocksize = blocksizes[0]
            # number of elements/blocks in each batch (total, not nnz)
            batch_mask_shape = sparse_shape
            if layout in blocked_layouts:
                # if we are blocked, the mask is generated for the block-valued elements
                batch_mask_shape = sparse_shape[0] // blocksize[0], sparse_shape[1] // blocksize[1]

            # random bool vector w/ length equal to max possible nnz for the sparse_shape
            mask_source = make_tensor(batch_mask_shape, dtype=torch.bool, device=device).flatten()
            n_batch = functools.reduce(operator.mul, batch_shape, 1)

            # stack random permutations of the source for each batch
            mask = torch.stack([mask_source[torch.randperm(mask_source.numel())]
                               for _ in range(n_batch)], dim=0).reshape(batch_shape + batch_mask_shape)
            if layout in blocked_layouts:
                # for blocked layouts we need to do a bit of extra work to expand the mask from block-space to element-space
                mask_shape = mask.shape
                mask = mask.view(mask_shape + (1, 1))
                mask = mask.expand(mask_shape + blocksize)
                mask = mask.transpose(-3, -2)
                mask = mask.flatten(-4, -3).flatten(-2, -1)
            mask_shape = mask.shape
            mask = mask.view(mask_shape + (1,) * len(hybrid_shape))
            mask = mask.expand(mask_shape + hybrid_shape)
            dense = dense * mask
            sparse = dense.to_sparse(layout=layout, blocksize=blocksize or None, dense_dim=len(hybrid_shape))
            check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)

            dense_back = sparse.to_dense()
            self.assertEqual(dense, dense_back)

            # if batches have different nnz we expect the conversion to throw
            mask_0 = mask[0]
            mask_1 = mask[0].clone().fill_(True)
            mask_2 = mask[0].clone().fill_(False)
            mask_true = mask_source.clone().fill_(True)
            mask_false = mask_source.clone().fill_(False)
            mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
            dense = make_tensor(shape, dtype=torch.float, device=device)
            dense = dense * mask
            msg = "Expect the same number of specified elements per batch."
            with self.assertRaisesRegex(RuntimeError, msg):
                dense.to_sparse(layout=layout, blocksize=blocksize or None)

            # Should throw if there is a zero in the batch size
            dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
            layout_code = str(layout).split("_")[-1]
            msg = f"to_sparse_{layout_code}: Expected product of batch dimensions to be non-zero."
3430 with self.assertRaisesRegex(RuntimeError, msg): 3431 dense.to_sparse(layout=layout, blocksize=blocksize or None) 3432 3433 @skipMeta 3434 @all_sparse_compressed_layouts() 3435 @coalescedonoff 3436 @dtypes(torch.double) 3437 @unittest.skipIf(not TEST_SCIPY, "SciPy not found") 3438 def test_sparse_to_sparse_compressed(self, device, dtype, coalesced, layout): 3439 """ 3440 This test tests conversion from COO to CSR and CSC and CSC to CSR and CSC 3441 by comparing to SciPy's implementation. 3442 3443 Here we test only those conversion combinations that SciPy 3444 supports to ensure that PyTorch conversions are in the same 3445 page with SciPy. Independent from SciPy, all conversion 3446 combinations are tested in TestSparseAny.test_to_sparse. 3447 """ 3448 3449 blocksize_kw = {} 3450 if layout in (torch.sparse_bsc, torch.sparse_bsr): 3451 blocksize_kw['blocksize'] = (2, 2) 3452 # block modes don't support 0 width/height 3453 shapes = [(6, 10)] 3454 elif layout in (torch.sparse_csc, torch.sparse_csr): 3455 shapes = [(0, 10), (6, 0), (6, 10), (0, 0)] 3456 else: 3457 raise NotImplementedError("unhandled layout") 3458 3459 if layout in (torch.sparse_bsc, torch.sparse_csc): 3460 compressed_indices_mth = torch.Tensor.ccol_indices 3461 plain_indices_mth = torch.Tensor.row_indices 3462 elif layout in (torch.sparse_bsr, torch.sparse_csr): 3463 compressed_indices_mth = torch.Tensor.crow_indices 3464 plain_indices_mth = torch.Tensor.col_indices 3465 else: 3466 raise NotImplementedError("unhandled layout") 3467 3468 for shape in shapes: 3469 sparse_dim = 2 3470 nnz = shape[0] * shape[1] // 2 3471 sparse, _, _ = self.genSparseTensor(shape, sparse_dim, nnz, coalesced, device, dtype) 3472 sp_matrix = self._construct_sp_matrix(sparse, layout) 3473 pt_matrix = sparse.to_sparse(layout=layout, **blocksize_kw) 3474 3475 self.assertEqual(layout, pt_matrix.layout) 3476 self.assertEqual(sp_matrix.shape, pt_matrix.shape) 3477 self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix)) 3478 self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix)) 3479 self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values()) 3480 3481 sparse_csc = sparse.to_sparse_csc() 3482 sp_matrix = self._construct_sp_matrix(sparse_csc, layout) 3483 pt_matrix = sparse_csc.to_sparse(layout=layout, **blocksize_kw) 3484 3485 self.assertEqual(layout, pt_matrix.layout) 3486 self.assertEqual(sp_matrix.shape, pt_matrix.shape) 3487 self.assertEqual(torch.tensor(sp_matrix.indptr, dtype=torch.int64), compressed_indices_mth(pt_matrix)) 3488 self.assertEqual(torch.tensor(sp_matrix.indices, dtype=torch.int64), plain_indices_mth(pt_matrix)) 3489 self.assertEqual(torch.tensor(sp_matrix.data), pt_matrix.values()) 3490 3491 @unittest.skipIf(not TEST_CUDA_CUDSS, "The test requires cudss") 3492 @dtypes(*floating_types()) 3493 def test_linalg_solve_sparse_csr_cusolver(self, device, dtype): 3494 # https://github.com/krshrimali/pytorch/blob/f5ee21dd87a7c5e67ba03bfd77ea22246cabdf0b/test/test_sparse_csr.py 3495 3496 try: 3497 spd = torch.rand(4, 3) 3498 A = spd.T @ spd 3499 b = torch.rand(3).cuda() 3500 A = A.to_sparse_csr().cuda() 3501 x = torch.sparse.spsolve(A, b) 3502 except RuntimeError as e: 3503 if "Calling linear solver with sparse tensors requires compiling " in str(e): 3504 self.skipTest("PyTorch was not built with cuDSS support") 3505 3506 samples = sample_inputs_linalg_solve(None, device, dtype) 3507 3508 for sample in samples: 3509 if sample.input.ndim != 2: 
3510 continue 3511 3512 out = torch.zeros(sample.args[0].size(), dtype=dtype, device=device) 3513 if sample.args[0].ndim != 1 and sample.args[0].size(-1) != 1: 3514 with self.assertRaisesRegex(RuntimeError, "b must be a 1D tensor"): 3515 out = torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs) 3516 break 3517 if not sample.args[0].numel(): 3518 with self.assertRaisesRegex(RuntimeError, 3519 "Expected non-empty other tensor, but found empty tensor"): 3520 torch.linalg.solve(sample.input.to_sparse_csr(), *sample.args, **sample.kwargs, out=out) 3521 break 3522 3523 expect = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs) 3524 sample.input = sample.input.to_sparse_csr() 3525 if sample.args[0].ndim != 1 and sample.args[0].size(-1) == 1: 3526 expect = expect.squeeze(-1) 3527 sample.args = (sample.args[0].squeeze(-1), ) 3528 out = torch.linalg.solve(sample.input, *sample.args, **sample.kwargs) 3529 self.assertEqual(expect, out) 3530 3531 3532def skipIfNoTriton(cls): 3533 from torch.utils._triton import has_triton 3534 3535 # no-op if triton is present 3536 if has_triton(): 3537 return cls 3538 else: 3539 3540 @functools.wraps(cls, updated=()) 3541 class skipped_cls(cls): 3542 def setUp(self): 3543 self.skipTest("Triton is not available.") 3544 3545 return skipped_cls 3546 3547@skipIfNoTriton 3548class TestSparseCompressedTritonKernels(TestCase): 3549 3550 def _to_block_triangular_inplace(self, d, row_block, col_block): 3551 """ 3552 This function modifies `d` to become (upper/lower) block-triangular in-place. 3553 It is assumed that `d.shape[-2]` is divisible by `row_block` and 3554 `d.shape[-1]` is divisible by `col_block`. 3555 """ 3556 3557 from torch.sparse._triton_ops import tile_to_blocksize 3558 3559 m, n = d.shape[-2:] 3560 d_tiled = tile_to_blocksize(d, (row_block, col_block)) 3561 d_tiled = d_tiled.moveaxis(-4, -1).moveaxis(-4, -1) 3562 if m // row_block > n // col_block: 3563 d_tiled.tril_() 3564 else: 3565 d_tiled.triu_() 3566 3567 return d 3568 3569 @onlyCUDA 3570 @skipIfRocm(msg="test is too slow on ROCm stack") 3571 @dtypes(torch.half, torch.bfloat16, torch.float) 3572 @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) 3573 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") 3574 def test_triton_bsr_softmax(self, device, dtype): 3575 from functools import partial 3576 from torch.sparse._triton_ops import bsr_softmax 3577 3578 tensor = partial(make_tensor, device=device, dtype=dtype, low=1.0, high=3.0) 3579 3580 # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`. 
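        # Zero-sized sparse dimensions are still exercised via `size` below;
        # only the batch shapes avoid zero sizes.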
3581 batches = [(), (2,), (2, 2)] 3582 size = [6, 12, 0] 3583 block_size = [2, 3] 3584 3585 # General correctness 3586 for row_block, col_block, b, m, n in itertools.product(block_size, block_size, batches, size, size): 3587 input = tensor(b + (m, n)) 3588 input.diagonal(dim1=-2, dim2=-1).fill_(m * n) 3589 input = self._to_block_triangular_inplace(input, row_block, col_block) 3590 3591 bsr = input.to_sparse_bsr((row_block, col_block)) 3592 coo = input.to_sparse().to(torch.float) 3593 3594 res_tri = bsr_softmax(bsr) 3595 res_coo = torch.sparse.softmax(coo, -1) 3596 self.assertEqual(res_tri, res_coo.to(input.dtype)) 3597 3598 # Test long rows which exceed Triton's max numel limit set to 2 ** 17 3599 input = tensor(b + (1, 150000)) 3600 bsr = input.to_sparse_bsr(1) 3601 self.assertEqual(input.softmax(-1), bsr_softmax(bsr)) 3602 3603 @parametrize("block_size", [16, 32, 64]) 3604 @parametrize("index_dtype", [torch.int32, torch.int64]) 3605 @onlyCUDA 3606 @dtypes(torch.half, torch.bfloat16, torch.float) 3607 @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) 3608 @unittest.skipIf((not TEST_WITH_TORCHINDUCTOR) or (IS_FBCODE and IS_REMOTE_GPU) or torch._running_with_deploy(), 3609 "Skipped for deploy and internal with remote GPUs") 3610 def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size): 3611 from functools import partial 3612 from torch.sparse._triton_ops import bsr_dense_mm 3613 3614 def kernel_impl(*args, **kwargs): 3615 return bsr_dense_mm(*args, skip_checks=True, **kwargs) 3616 3617 kernel = torch._TritonLibrary.registerOp( 3618 "_triton_bsr_dense_mm_out", 3619 "_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)", 3620 kernel_impl, 3621 "SparseCsrCUDA" 3622 ) 3623 3624 # kernel != kernel_impl means dispatch was already registered. 3625 # This is exactly what we need! 3626 self.assertTrue(kernel is not kernel_impl) 3627 3628 # Note that each value in a non-zero block is in range block_size * [low^2, high^2). 3629 tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5) 3630 3631 # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`. 3632 batches = [(), (2,), (2, 2)] 3633 size = [128, 256, 0] 3634 3635 # Whether to make inputs orthogonal so that the product is zero 3636 make_orthogonal = [True, False] 3637 3638 for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal): 3639 bsr = tensor(bs + (m, k)) 3640 # NOTE: do not get confused, it will be transposed 3641 dense = tensor(bd + (n, k)) 3642 3643 if is_ortho: 3644 bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1) 3645 dense = torch.cat((torch.zeros_like(dense), dense), dim=-1) 3646 3647 bsr = bsr.to_sparse_bsr(block_size) 3648 3649 if bsr.dim() == 2 and dtype != torch.float: 3650 # Test against linear to check dispatch 3651 # which takes place for torch.half and torch.bfloat16. 3652 res_dense = torch.nn.functional.linear(dense, bsr.to_dense()) 3653 res_tri_out = torch.empty_like(res_dense) 3654 res_tri = torch.nn.functional.linear(dense, bsr, out=res_tri_out) 3655 3656 # Check dispatch worked with non-trivial outputs 3657 if m > 0 and n > 0 and k > 0: 3658 self.assertTrue(kernel.kernel_invoked) 3659 kernel.kernel_invoked = False 3660 else: 3661 # Otherwise check correctness against bmm 3662 # since nn.linear does not support bsr.dim() > 2. 
3663 res_dense = bsr.to_dense() @ dense.transpose(-2, -1) 3664 res_tri_out = torch.empty_like(res_dense) 3665 res_tri = kernel(bsr, dense.transpose(-2, -1), out=res_tri_out) 3666 3667 self.assertTrue(res_tri is res_tri_out) 3668 self.assertEqual(res_tri, res_dense) 3669 3670 res_dense = bsr.to_dense() @ dense.transpose(-2, -1) 3671 # check whether bsr_dense_mm handles different grid sizes 3672 # None means max possible grid size which is CUDA-dependent. 3673 grid_size = (None, 2, 4) 3674 grid_gen = itertools.product(grid_size, repeat=3) 3675 for grid in grid_gen: 3676 res_tri = torch.sparse._triton_ops.bsr_dense_mm( 3677 bsr, 3678 dense.transpose(-2, -1), 3679 max_grid=grid, 3680 ) 3681 self.assertEqual(res_tri, res_dense) 3682 3683 @onlyCUDA 3684 @dtypes(torch.half) 3685 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU or torch._running_with_deploy(), 3686 "Skipped for deploy and internal with remote GPUs") 3687 def test_triton_bsr_dense_bmm_error_messages(self, device, dtype): 3688 from torch.sparse._triton_ops import bsr_dense_mm 3689 3690 rhs = torch.rand(32, 32, dtype=dtype, device=device) 3691 lhs = rhs.to_sparse_bsr(16) 3692 with self.assertRaisesRegex(ValueError, "only BSR sparse format is supported"): 3693 bsr_dense_mm(lhs.to_sparse_bsc(16), rhs) 3694 with self.assertRaisesRegex(ValueError, "on the same GPU device"): 3695 bsr_dense_mm(lhs, rhs.cpu()) 3696 if torch.cuda.device_count() > 1: 3697 with self.assertRaisesRegex(ValueError, "on the same GPU device"): 3698 bsr_dense_mm(lhs.to("cuda:0"), rhs.to("cuda:1")) 3699 with self.assertRaisesRegex(ValueError, "all inputs are expected to be of the same dtype"): 3700 bsr_dense_mm(lhs, rhs.to(torch.float)) 3701 with self.assertRaisesRegex(ValueError, r"and one of \(half, bfloat16, float32\)"): 3702 bsr_dense_mm(lhs.to(torch.double), rhs.to(torch.double)) 3703 with self.assertRaisesRegex(ValueError, "all inputs involved in the matrix product are expected to be at least 2D"): 3704 bsr_dense_mm(lhs, torch.rand(1, dtype=dtype, device=device)) 3705 with self.assertRaisesRegex(ValueError, 3706 "sizes involved in the matrix product are not compatible for matrix multiplication"): 3707 bsr_dense_mm(lhs, torch.rand(1, 1, dtype=dtype, device=device)) 3708 with self.assertRaisesRegex(ValueError, 3709 r"dense.size\(-1\) == 15 should be divisible by 16"): 3710 bsr_dense_mm(lhs, torch.rand(32, 15, dtype=dtype, device=device)) 3711 # Blocksizes check 3712 for blocksize in (15, 30): 3713 n = blocksize * 2 3714 rhs = torch.rand(n, n, dtype=dtype, device=device) 3715 lhs = rhs.to_sparse_bsr(blocksize) 3716 with self.assertRaisesRegex(ValueError, "should be at least 16 and a power of 2"): 3717 bsr_dense_mm(lhs, rhs) 3718 # out check 3719 rhs = torch.rand(2, 32, 32, dtype=dtype, device=device) 3720 lhs = rhs.to_sparse_bsr(16) 3721 with self.assertRaisesRegex(ValueError, r"`out` argument has wrong shape"): 3722 out = torch.rand(2, 30, 30, dtype=dtype, device=device) 3723 bsr_dense_mm(lhs, rhs, out=out) 3724 with self.assertRaisesRegex(ValueError, r"only row-major/col-major `out`"): 3725 out = torch.rand(32, 32, 2, dtype=dtype, device=device).transpose(0, -1) 3726 bsr_dense_mm(lhs, rhs, out=out) 3727 3728 @parametrize("block_size", [16, 32, 64]) 3729 @onlyCUDA 3730 @skipIfRocm 3731 @dtypes(torch.half, torch.bfloat16, torch.float) 3732 @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) 3733 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") 3734 @precisionOverride({torch.float16: 1e-3}) 3735 def 
test_triton_scaled_dot_product_attention(self, device, dtype, block_size): 3736 from functools import partial 3737 from torch.sparse._triton_ops import _scaled_dot_product_attention 3738 3739 # Note that each value in a non-zero block is in range block_size * [low^2, high^2). 3740 tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2) 3741 3742 def broadcast_input(*ts): 3743 batch_dims = torch.broadcast_shapes(*(t.shape[:-2] for t in ts)) 3744 yield from (torch.broadcast_to(t, batch_dims + t.shape[-2:]) for t in ts) 3745 3746 # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`. 3747 batches = [(), (2,), (2, 2)] 3748 size = [128, 256, 0] 3749 3750 for bam, bq, bk, bv, m, n, k in itertools.product(batches, batches, batches, batches, size, size, size): 3751 query = tensor(bq + (m, k)) 3752 key = tensor(bk + (n, k)) 3753 value = tensor(bv + (n, k)) 3754 3755 # We make attn_mask block lower/upper triangular so that BSR and Strided 3756 # function variants are directly comparable. 3757 attn_mask = torch.ones(bam + (m, n), device=device, dtype=torch.bool) 3758 attn_mask = self._to_block_triangular_inplace(attn_mask, block_size, block_size) 3759 attn_mask_bsr = attn_mask.to_sparse_bsr(block_size) 3760 3761 # NOTE: only boolean mask is directly compatible with the Strided version 3762 # without any pre-/post-processing. Hence we test against a boolean mask. 3763 for scale in (None, 1. / 16): 3764 if scale is None and query.size(-1) == 0: 3765 scale = 1 3766 expected = torch.nn.functional.scaled_dot_product_attention( 3767 *broadcast_input(query, key, value, attn_mask), scale=scale 3768 ) 3769 3770 for mask_dtype in (torch.bool, dtype): 3771 res = _scaled_dot_product_attention(query, key, value, attn_mask_bsr.to(mask_dtype), scale=scale) 3772 self.assertEqual(res, expected) 3773 3774 3775 @parametrize("block_size", [16, 32, 64]) 3776 @onlyCUDA 3777 @skipIfRocm 3778 @dtypes(torch.half, torch.bfloat16, torch.float) 3779 @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) 3780 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") 3781 def test_triton_sampled_addmm(self, device, dtype, block_size): 3782 from functools import partial 3783 from torch.sparse._triton_ops import sampled_addmm, broadcast_batch_dims_bsr 3784 3785 # Note that each value in a non-zero block is in range block_size * [low^2, high^2). 3786 tensor = partial(make_tensor, device=device, dtype=dtype, low=0.3, high=1.2) 3787 3788 # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`. 3789 batches = [(), (2,), (2, 2)] 3790 size = [128, 256, 0] 3791 3792 delta_k = (-3,) 3793 for bi, bm1, bm2, m, n, k, dk in itertools.product(batches, batches, batches, size, size, size, delta_k): 3794 # Test not powers of 2 ks as well. 3795 k = max(0, k + dk) 3796 # Non-trivial sparsity pattern. 3797 # Plus with tril inputs the result is also tril, 3798 # so we can compare BSR and CSR implementations. 
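            # With lower-triangular input, mat1 and mat2, the product and the
            # sampled result are lower triangular as well, so the block-granular
            # BSR pattern and the element-granular CSR pattern yield the same
            # dense result below.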
3799 input = tensor(bi + (m, n)).tril_() 3800 bsr = input.to_sparse_bsr(block_size) 3801 mat1 = tensor(bm1 + (m, k)).tril_() 3802 mat2 = tensor(bm2 + (k, n)).tril_() 3803 3804 batch_dim = torch.broadcast_shapes(input.shape[:-2], mat1.shape[:-2], mat2.shape[:-2]) 3805 3806 csr = input.broadcast_to(batch_dim + input.shape[-2:]).to_sparse_csr().to(torch.float) 3807 mat1csr = mat1.broadcast_to(batch_dim + mat1.shape[-2:]).to(torch.float) 3808 mat2csr = mat2.broadcast_to(batch_dim + mat2.shape[-2:]).to(torch.float) 3809 3810 input_broadcasted_clone = broadcast_batch_dims_bsr( 3811 "test_triton_sampled_addmm", 3812 bsr, mat1, mat2 3813 ).clone() 3814 input_broadcasted_clone = torch.sparse_compressed_tensor( 3815 input_broadcasted_clone.crow_indices(), 3816 input_broadcasted_clone.col_indices(), 3817 # For testing `out=` let's make values to have "weird" strides 3818 # so that if the kernel modifies values to it's needs, the result 3819 # is being compied into out.values. 3820 input_broadcasted_clone.values().transpose(-3, -2).contiguous().transpose(-3, -2), 3821 layout=input_broadcasted_clone.layout, 3822 size=input_broadcasted_clone.shape 3823 ) 3824 3825 scalars = (0.0, 2.0) 3826 for alpha, beta, out in itertools.product(scalars, scalars, (None, input_broadcasted_clone)): 3827 res_tri = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, out=out) 3828 if out is not None: 3829 self.assertTrue(res_tri is out) 3830 3831 batch_broadcasted_shape = torch.broadcast_shapes(*(t.shape[:-2] for t in (input, mat1, mat2))) 3832 self.assertTrue(res_tri.shape == batch_broadcasted_shape + (m, n)) 3833 3834 res_csr = torch.sparse.sampled_addmm(csr, mat1csr, mat2csr, alpha=alpha, beta=beta).to(input.dtype) 3835 self.assertEqual(res_tri.to_dense(), res_csr.to_dense()) 3836 3837 # Check different grid sizes to make sure that input slicing works 3838 # if this input is larger than the grid. 
3839 grid_size = (3, None) 3840 grid_gen = itertools.product(grid_size, repeat=2) 3841 for grid in grid_gen: 3842 res_tri_grid = sampled_addmm(bsr, mat1, mat2, alpha=alpha, beta=beta, max_grid=grid) 3843 self.assertEqual(res_tri, res_tri_grid) 3844 3845 @onlyCUDA 3846 @skipIfRocm 3847 @dtypes(torch.half, torch.bfloat16, torch.float) 3848 @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) 3849 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") 3850 def test_triton_scatter_mm(self, device, dtype): 3851 from torch.sparse._triton_ops import scatter_mm 3852 from functools import partial 3853 tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5) 3854 sizes = [8, 16] 3855 for m, k, n in itertools.product(sizes, sizes, sizes): 3856 blocks = torch.stack([tensor(m, k), tensor(m, k)]) 3857 others = torch.stack([tensor(k, n), tensor(k, n)]) 3858 3859 expected = torch.stack([blocks[0] @ others[0] + blocks[1] @ others[0], 3860 blocks[0] @ others[1], 3861 blocks[1] @ others[1]]) 3862 3863 indices_data = ( 3864 'scatter_mm', 3865 torch.tensor([0, 2, 3, 4], dtype=torch.int32, device=device), 3866 torch.tensor([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=torch.int32, device=device)) 3867 3868 result = scatter_mm(blocks, others, indices_data=indices_data) 3869 3870 self.assertEqual(result, expected) 3871 3872 indices_data = ( 3873 'bsr_strided_mm', 3874 torch.tensor([0, 2, 4, 5, 6], dtype=torch.int32, device=device), 3875 torch.tensor([0, n, 2 * n * m, 2 * n * m + n], dtype=torch.int32, device=device), 3876 torch.tensor([1, 0, 1, 0, 1, 1], dtype=torch.int32, device=device), 3877 torch.tensor([0, 2 * k * n, n, 2 * k * n + n, 2 * k * n, 2 * k * n + n], 3878 dtype=torch.int32, device=device), 3879 dict(SPLIT_N=2, is_compressed=False, TILE_M=m, TILE_N=n, GROUP_SIZE=1) 3880 ) 3881 3882 for bsize in [(), (2,), (3, 4)]: 3883 other = tensor(*bsize, 2 * k, 2 * n) 3884 expected = torch.cat([ 3885 torch.cat([blocks[1], blocks[0]], dim=1), 3886 torch.cat([torch.zeros_like(blocks[0]), blocks[1]], dim=1)], dim=0) @ other 3887 result = scatter_mm(blocks, other, indices_data=indices_data) 3888 self.assertEqual(result, expected) 3889 3890 @parametrize("blocksize", [2, '2x3', 16, '16x32', 32, 64]) 3891 @onlyCUDA 3892 @dtypes(torch.half, torch.bfloat16, torch.float) 3893 @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) 3894 @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") 3895 def test_triton_bsr_scatter_mm(self, device, dtype, blocksize): 3896 import triton 3897 from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data 3898 from functools import partial 3899 if isinstance(blocksize, str): 3900 blocksize = tuple(map(int, blocksize.split('x'))) 3901 else: 3902 blocksize = (blocksize,) * 2 3903 # Note that each value in a non-zero block is in range blocksize * [low^2, high^2). 3904 tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5) 3905 3906 # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`. 
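        # The sizes below are multiples of the block size so that
        # `to_sparse_bsr(blocksize)` tiles the matrices evenly.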
3907 batches = [(), (2,), (2, 2)] 3908 sizes = [blocksize[0], 2 * blocksize[0], 4 * blocksize[0]] 3909 sizes_K = [blocksize[1], 2 * blocksize[1]] 3910 3911 for bd, bs, M, K, N, has_zero_row_block in itertools.product(batches, batches[:1], sizes, sizes_K, sizes, (False, True)): 3912 bsr_dense = tensor(bs + (M, K)) 3913 if has_zero_row_block: 3914 if M > blocksize[0]: 3915 bsr_dense[:blocksize[0]].zero_() 3916 else: 3917 continue 3918 bsr = bsr_dense.to_sparse_bsr(blocksize) 3919 dense = tensor(bd + (K, N)) 3920 expected = bsr.to_dense() @ dense 3921 3922 for indices_format in ('bsr_strided_mm', 'bsr_strided_mm_compressed', 'scatter_mm'): 3923 if indices_format in {'bsr_strided_mm', 'bsr_strided_mm_compressed'}: 3924 SPLIT_N_list = [N] 3925 while SPLIT_N_list[-1] > 1: 3926 SPLIT_N_list.append(max(1, SPLIT_N_list[-1] // 2)) 3927 else: 3928 SPLIT_N_list = [1] 3929 for SPLIT_N in SPLIT_N_list: 3930 indices_data = bsr_scatter_mm_indices_data( 3931 bsr, dense, indices_format=indices_format, SPLIT_N=SPLIT_N) 3932 try: 3933 result = bsr_scatter_mm(bsr, dense, indices_data=indices_data) 3934 except triton.compiler.OutOfResources: 3935 # ensure that there was at least one succesful test: 3936 assert SPLIT_N < SPLIT_N_list[0] 3937 break 3938 3939 self.assertEqual(result, expected) 3940 torch.sparse._triton_ops._bsr_scatter_mm_indices_data.cache_clear() 3941 3942 def test_TensorAsKey(self, device): 3943 from torch.sparse._triton_ops import TensorAsKey 3944 assertEqualOptions = dict(exact_dtype=True, exact_device=True, exact_layout=True) 3945 3946 t = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device=device) 3947 key = TensorAsKey(t) 3948 self.assertTrue(key == TensorAsKey(t)) 3949 self.assertTrue(key.obj is t) 3950 3951 t2 = t[:] 3952 key2 = TensorAsKey(t2) 3953 self.assertTrue(key == key2) 3954 self.assertEqual(key2.obj, t, **assertEqualOptions) 3955 # deleting object leads to dead key 3956 del t2 3957 self.assertTrue(key2.obj is None) 3958 self.assertTrue(key.obj is t) 3959 3960 # key with different storage offset and shape: 3961 self.assertFalse(key == TensorAsKey(t[1:])) 3962 3963 # key with different strides: 3964 self.assertFalse(key == TensorAsKey(t[::2])) 3965 3966 # when object dies, make sure that key represents a dead 3967 # object as well: 3968 del t 3969 self.assertTrue(key.obj is None) 3970 3971 # Storing a tensor as a dict key: 3972 d = {} 3973 t3 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device) 3974 key3 = TensorAsKey(t3) 3975 d[key3] = 123 3976 self.assertTrue(d.get(key3) == 123) 3977 t3_ = t3[:] 3978 self.assertTrue(d.get(TensorAsKey(t3_)) == 123) 3979 self.assertTrue(d.get(TensorAsKey(t3.clone())) is None) 3980 3981 d[TensorAsKey(t3_)] = 567 3982 self.assertTrue(d.get(key3) == 567) 3983 3984 # t3 and t3_ reference the same data, so, the key becomes dead 3985 # (that is, its .obj property returns None) until all 3986 # references are deleted: 3987 del t3 3988 self.assertTrue(key3.obj is not None) 3989 self.assertTrue(d.get(key3) == 567) 3990 del t3_ 3991 self.assertTrue(key3.obj is None) 3992 self.assertTrue(d.get(key3) == 567) 3993 3994 # Storing a tensor as a dict key and value: 3995 d = {} 3996 t4 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device) 3997 key4 = TensorAsKey(t4) 3998 d[key4] = (t4, 123) 3999 self.assertEqual(d.get(key4), (t4, 123), **assertEqualOptions) 4000 # when object is deleted, the key represents an alive object 4001 # because the object is referenced by the dict item value: 4002 del t4 4003 self.assertTrue(key4.obj is not None) 4004 
        # This also means that the lifetime of the tensor is the same as
        # the lifetime of the corresponding dict item:
        del d[key4]
        self.assertTrue(key4.obj is None)

        # Storing a tensor as a dict key and value wrapped with TensorAsKey:
        d = {}
        t5 = torch.tensor([1, 2, 3, 4], dtype=torch.int32, device=device)
        key5 = TensorAsKey(t5)
        d[key5] = (key5, 567)
        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)
        self.assertTrue(key5.obj is not None)
        # when the object is deleted, the key becomes dead because the wrapped
        # value holds the tensor instance only as a weakref:
        del t5
        self.assertTrue(key5.obj is None)
        # but the key is still valid:
        self.assertEqual(d.get(key5), (key5, 567), **assertEqualOptions)

    @suppress_warnings
    @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear', '_int_bsr_dense_addmm'])
    @parametrize("blocksize", [16, '16x32', 32])
    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_kernel(self, op, device, dtype, blocksize):
        from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm
        from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta,
                                                   optimize_bsr_dense_addmm, dump)

        def bsr_dense_linear(input, weights, bias=None):
            return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2)

        operation = dict(bsr_dense_addmm=bsr_dense_addmm, bsr_dense_mm=bsr_dense_mm, bsr_dense_linear=bsr_dense_linear,
                         _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]

        def reference(input, mat1, mat2, beta=1, alpha=1, op=op):
            assert mat1.layout is torch.strided
            assert mat2.layout is torch.strided
            if dtype is torch.int8:
                if op == '_int_bsr_dense_addmm':
                    return beta * input + alpha * torch._int_mm(mat1, mat2)
                # workaround RuntimeError: "addmm_cuda" not implemented for 'Char'
                return beta * input + alpha * torch._int_mm(mat1, mat2).to(torch.int8)
            return beta * input + alpha * (mat1 @ mat2)

        if op == '_int_bsr_dense_addmm':
            # _int_bsr_dense_addmm is the same as bsr_dense_addmm except
            # that with int8 inputs _int_bsr_dense_addmm returns an int32
            # result. This is covered by the operation and reference
            # definitions above; all other definitions below are
            # identical between _int_bsr_dense_addmm and
            # bsr_dense_addmm.
            op = 'bsr_dense_addmm'

        def nc_copy(t, axes=(-1,)):
            """Return a copy of input.

            The returned copy will be a non-contiguous tensor.
            """
            if t.layout is torch.strided:
                shape = list(t.shape)
                for a in axes:
                    shape[a] *= 2
                r = torch.empty(shape, dtype=t.dtype, device=t.device)
                s = r[tuple(slice(None, None, 2 if t.shape[i] != r.shape[i] else None) for i in range(t.ndim))]
                s.copy_(t)
                return s
            elif t.layout is torch.sparse_bsr:
                compressed_indices = t.crow_indices()
                plain_indices = t.col_indices()
                return torch.sparse_compressed_tensor(compressed_indices, plain_indices, nc_copy(t.values()),
                                                      t.shape, layout=t.layout)
            else:
                raise NotImplementedError(t.layout)

        if isinstance(blocksize, str):
            BM, BK = tuple(map(int, blocksize.split('x')))
        else:
            BM, BK = (blocksize,) * 2

        if op in {"bsr_dense_linear"} and BM != BK:
            # todo: eliminate this skip
            self.skipTest(f"{op} does not support non-square blocks")

        if op in {"bsr_dense_linear"} and dtype is torch.int8:
            # todo: eliminate this skip
            self.skipTest(f"{op} does not support int8")

        if dtype is torch.int8 and min(BM, BK) < 32:
            self.skipTest("triton kernel does not support int8 blocks smaller than 32")

        beta_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[0], bsr_dense_linear=[1])[op]
        alpha_lst = dict(bsr_dense_addmm=[0, 1, 2], bsr_dense_mm=[1], bsr_dense_linear=[1])[op]
        sparsity_lst = [0, 0.5, 1]
        blocks_per_row_lst = [1, 2]
        blocks_per_col_lst = [1, 2]
        result_cols_lst = [16, 32, 64]
        for beta, alpha, sparsity, blocks_per_row, blocks_per_col, N in itertools.product(
                beta_lst, alpha_lst, sparsity_lst, blocks_per_row_lst, blocks_per_col_lst, result_cols_lst):
            M = BM * blocks_per_row
            K = BK * blocks_per_col
            mat1 = create_blocked_tensor(0, M, K, (BM, BK), sparsity, dtype, device=device)
            bsr = mat1.to_sparse_bsr((BM, BK))
            mat2 = make_tensor(K, N, dtype=dtype, device=device, low=0.5, high=1.5)
            input = make_tensor(M, N, dtype=dtype, device=device, low=0.5, high=1.5)

            if 0 and op == "bsr_dense_addmm":
                # Find optimal kernel parameters; this gives roughly a 10x
                # speed-up for running this test.
                #
                # Enable this if-block when the test method is
                # updated, run the test, and finally, disable the
                # if-block.
                key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1)
                meta = get_meta(op, key, version=(0, dtype, 0.5))
                if meta is None:
                    optimize_bsr_dense_addmm(M, K, N, BM, BK, beta=beta, alpha=alpha, dtype=dtype, sparsity=0.5)
                    meta = get_meta(op, key, version=(0, dtype, 0.5))
                    assert meta is not None
                    dump()  # this will update torch/sparse/_triton_ops_meta.py

            expected = reference(input, mat1, mat2, beta=beta, alpha=alpha)
            kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha), bsr_dense_mm={},
                          bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op]

            args = dict(bsr_dense_addmm=(input, bsr, mat2), bsr_dense_mm=(bsr, mat2),
                        bsr_dense_linear=(mat2.transpose(-1, -2), bsr))[op]
            result = operation(*args, **kwargs)
            self.assertEqual(result, expected)

            # Test non-contiguous input tensors:
            nc_mat2 = nc_copy(mat2)
            nc_input = nc_copy(input)
            nc_bsr = nc_copy(bsr)

            args = dict(bsr_dense_addmm=(input, bsr, nc_mat2), bsr_dense_mm=(bsr, nc_mat2),
                        bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
            result = operation(*args, **kwargs)
            self.assertEqual(result, expected)

            # todo: add bsr_dense_linear to the set below (currently,
            # nn.linear has unnecessarily restrictive argument
            # checks).
            if op in {'bsr_dense_addmm', 'bsr_dense_mm'}:
                args = dict(bsr_dense_addmm=(input, nc_bsr, mat2), bsr_dense_mm=(nc_bsr, mat2),
                            bsr_dense_linear=(mat2.transpose(-1, -2), nc_bsr))[op]
                result = operation(*args, **kwargs)
                self.assertEqual(result, expected)

            if op in {'bsr_dense_addmm', 'bsr_dense_linear'}:
                args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2),
                            bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op]
                kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha),
                              bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op]
                result = operation(*args, **kwargs)
                self.assertEqual(result, expected)

    @parametrize("op", ['bsr_dense_addmm', '_int_bsr_dense_addmm'])
    @onlyCUDA
    @skipIfRocm
    @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8)
    @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8)
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_tune(self, op, device, dtype):
        from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm
        from torch.sparse._triton_ops_meta import (
            create_blocked_tensor, tune_bsr_dense_addmm, tune__int_bsr_dense_addmm, get_meta)

        operation = dict(bsr_dense_addmm=bsr_dense_addmm, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op]
        tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm,
                     _int_bsr_dense_addmm=tune__int_bsr_dense_addmm)[op]

        if op == '_int_bsr_dense_addmm':
            M, K, N = 32, 32, 32
            blocksize = (32, 32)
        else:
            M, K, N = 16, 16, 32
            blocksize = (16, 16)
        sparsity = 1.0
        bsr = create_blocked_tensor(0, M, K, blocksize, sparsity, dtype, device).to_sparse_bsr(blocksize)
        # Recompute the actual sparsity of the generated BSR tensor:
        sparsity = 1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K)
        input = make_tensor(K, N, dtype=dtype, device=device)
        dense = make_tensor(K, N, dtype=dtype, device=device)

        if op in {'bsr_dense_addmm', '_int_bsr_dense_addmm'}:
            args = (input, bsr, dense)

            def get_current_meta():
                version = (0, dtype, sparsity)
                meta_key = (M, K, N, *blocksize, False, True, True)
                return get_meta(op, meta_key, version=version, exact=True)
        else:
            raise NotImplementedError(op)

        # No tuned parameters are stored for this problem yet:
        self.assertEqual(get_current_meta(), None)

        # Tuning with store=True makes the tuned parameters retrievable via get_meta:
        meta = tuner(*args, **dict(store=True, verbose=False))
        self.assertEqual(get_current_meta(), meta)

        # Using the tuned parameters must not change the result of the operation:
        expected = operation(*args)
        result = operation(*args, **dict(meta=meta))
        self.assertEqual(result, expected)

    @onlyCUDA
    @skipIfRocm
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton")
    def test_triton_bsr_dense_addmm_meta(self, device):
        from torch.sparse._triton_ops import bsr_dense_addmm_meta
        from torch.sparse._triton_ops_meta import update as update_bsr_dense_addmm_meta

        dtype = torch.float32
        Ms = Ks = 16
        beta = 0.0
        alpha = 1.0

        def get_meta(M, K, N, sparsity=None):
            return bsr_dense_addmm_meta(M, K, N, Ms, Ks, beta, alpha, dtype=dtype, sparsity=sparsity,
                                        _version="test_triton_bsr_dense_addmm_meta")

        def update_meta(M, K, N, value, sparsity=0.5):
            key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1)
            update_bsr_dense_addmm_meta("bsr_dense_addmm", torch.cuda.get_device_name(),
                                        ("test_triton_bsr_dense_addmm_meta", dtype, sparsity),
                                        key, value)

        def get_meta_with_checks(M, K, N, warn_count=0, sparsity=None):
            f = io.StringIO()
            with redirect_stderr(f):
                result = get_meta(M, K, N, sparsity=sparsity)
            msg = f.getvalue()
            FileCheck().check_count(
                str=f"UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M={M} K={K} N={N}",
                count=warn_count, exactly=True
            ).run(msg)
            return result

        # Test warn_once when requesting non-existing tuned parameters multiple times
        f = io.StringIO()
        with redirect_stderr(f):
            for i in range(5):
                get_meta(16, 16, 16)
            for i in range(5):
                get_meta(16, 16, 32)

        msg = f.getvalue()
        FileCheck().check_count(
            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=16",
            count=1, exactly=True
        ).run(msg)
        FileCheck().check_count(
            str="UserWarning: bsr_dense_addmm uses non-optimal triton kernel parameters for M=16 K=16 N=32",
            count=1, exactly=True
        ).run(msg)

        # Test warn_once when tuned parameters are missing
        default_meta = dict(GROUP_SIZE_ROW=4, SPLIT_N=2, num_stages=1, num_warps=4)
        self.assertEqual(get_meta_with_checks(32, 32, 32, warn_count=1), default_meta)

        # Test (no)warn_once when tuned parameters are available
        update_meta(32, 32, 48, (2, 8, 5, 6))
        expected_meta = dict(GROUP_SIZE_ROW=2, SPLIT_N=8, num_stages=5, num_warps=6)
        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0), expected_meta)

        # Test non-existing tuned parameters with non-default sparsity
        # while for the default sparsity 0.5 the parameters are available
        self.assertEqual(get_meta_with_checks(32, 32, 48, warn_count=0, sparsity=0.6), expected_meta)

        # Test non-existing tuned parameters while there exist
        # parameters with a consistent N // SPLIT_N ratio:
        self.assertEqual(get_meta_with_checks(32, 32, 72, warn_count=0),
                         dict(GROUP_SIZE_ROW=2, SPLIT_N=12, num_stages=5, num_warps=6))
        # ... or without such parameters, in which case non-optimal
        # default parameters are returned with a warning:
        self.assertEqual(get_meta_with_checks(32, 32, 64, warn_count=1),
                         dict(GROUP_SIZE_ROW=4, SPLIT_N=4, num_stages=1, num_warps=4))


# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())
instantiate_device_type_tests(TestSparseCompressed, globals())
instantiate_device_type_tests(TestSparseCompressedTritonKernels, globals())

if __name__ == '__main__':
    run_tests()