# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
    fallback_dispatcher,
    semi_sparse_addmm,
    semi_sparse_detach,
    semi_sparse_indices,
    semi_sparse_linear,
    semi_sparse_mm,
    semi_sparse_t,
    semi_sparse_values,
    semi_sparse_view,
)


__all__ = [
    "SparseSemiStructuredTensor",
    "SparseSemiStructuredTensorCUTLASS",
    "SparseSemiStructuredTensorCUSPARSELT",
    "to_sparse_semi_structured",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG",
    "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)


class SparseSemiStructuredTensor(torch.Tensor):
    """
    This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are zero,
    with n depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity.

    There are two backends available for semi-structured sparsity, either cuSPARSELt or CUTLASS.
    This class is meant to serve as a base class for both implementations. SparseSemiStructuredTensorCUTLASS
    and SparseSemiStructuredTensorCUSPARSELT both inherit from this class and define three backend-specific items.
    Note that as such, this class cannot be instantiated directly.

    - `_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend-specific dense/sparse min shape constraints
    - `def from_dense()` - backend-specific compression routines
    - `def _mm()` - backend-specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
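
    A minimal usage sketch (assumes a CUDA device with 2:4 sparse kernel support; the shapes are
    chosen to satisfy the fp16 constraints in `_DTYPE_SHAPE_CONSTRAINTS`):

        >>> # xdoctest: +SKIP
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()  # (64, 64), already 2:4 sparse
        >>> B = torch.rand(64, 64).half().cuda()
        >>> A_sparse = to_sparse_semi_structured(A)
        >>> torch.mm(A_sparse, B).shape
        torch.Size([64, 64])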
    """

    _DEFAULT_ALG_ID: int = 0
    _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
    _FORCE_CUTLASS: bool = True
    _FUSE_TRANSPOSE: bool = False
    _PROTOTYPE_WARNING_SHOWN: bool = False

    BACKEND: str
    SPARSE_DISPATCH: Dict[Callable, Callable]

    packed: Optional[torch.Tensor]
    meta: Optional[torch.Tensor]
    packed_t: Optional[torch.Tensor]
    meta_t: Optional[torch.Tensor]
    compressed_swizzled_bitmask: Optional[torch.Tensor]
    fuse_transpose_cusparselt: bool
    alg_id_cusparselt: int

    __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]

    @staticmethod
    def __new__(  # noqa: PYI034
        cls,
        shape: torch.Size,
        packed: Optional[torch.Tensor],
        meta: Optional[torch.Tensor],
        packed_t: Optional[torch.Tensor],
        meta_t: Optional[torch.Tensor],
        compressed_swizzled_bitmask: Optional[torch.Tensor],
        fuse_transpose_cusparselt: bool = False,
        alg_id_cusparselt: int = 0,
        requires_grad: bool = False,
    ):
        """
        Create a new instance of the tensor subclass from the compressed sparse representation.

        We have the option to create the subclass with the compressed representations of both X and X', for training.
        For inference, we only need a single representation (either X or X'), while the other set of fields will be None.

        Depending on the backend selected (cuSPARSELt vs CUTLASS), certain fields will be set to None.

        Args:
            shape: The shape of the original dense tensor
            packed: The compressed representation of the original dense tensor
            meta: The metadata of the original dense tensor, if it is stored separately
            packed_t: The compressed representation of the transposed original dense tensor
            meta_t: The metadata of the transposed original dense tensor, if it is stored separately
            compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
                participate in the computation. Used for pointwise ops.
            fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
                with a matmul, which is useful in the case of 2:4 sparse training.
            alg_id_cusparselt: The algorithm id to use when using cuSPARSELt; affects matmul performance.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If both packed and packed_t are None.
        """
        if not cls._PROTOTYPE_WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a Github issue "
                    "for feature requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            cls._PROTOTYPE_WARNING_SHOWN = True

            # Because this only runs once, we also load the dispatch table here as well.
            # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead.
            # But this is useful since it allows users to overload the dispatch table for debugging / testing.
            cls._load_dispatch_table()

            # we can also register the classes with dynamo when the warning is shown.
            torch._dynamo.allow_in_graph(cls)

        if packed is not None:
            previous_tensor = packed
        elif packed_t is not None:
            previous_tensor = packed_t
        else:
            raise ValueError("At least one of packed or packed_t must be provided")

        kwargs = {
            "device": previous_tensor.device,
            "dtype": previous_tensor.dtype,
            "layout": previous_tensor.layout,
            "requires_grad": requires_grad,
        }
        tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

        tensor.packed = packed
        tensor.meta = meta
        tensor.packed_t = packed_t
        tensor.meta_t = meta_t
        tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
        tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
        tensor.alg_id_cusparselt = alg_id_cusparselt
        return tensor

    def __repr__(self) -> str:  # type: ignore[override]
        assert hasattr(self, "shape")
        return f"{self.__class__.__name__}(shape={self.shape})"

    def __tensor_flatten__(
        self,
    ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
        tensor_meta = (
            self.shape,
            self.fuse_transpose_cusparselt,
            self.alg_id_cusparselt,
            self.requires_grad,
        )
        return inner_tensors, tensor_meta

    @classmethod
    def __tensor_unflatten__(
        cls,
        inner_tensors,
        tensor_meta: Tuple[torch.Size, bool, int, bool],
        outer_size,
        outer_stride,
    ) -> torch.Tensor:
        shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
        return cls(
            shape=shape,
            packed=inner_tensors.get("packed", None),
            meta=inner_tensors.get("meta", None),
            packed_t=inner_tensors.get("packed_t", None),
            meta_t=inner_tensors.get("meta_t", None),
            compressed_swizzled_bitmask=inner_tensors.get(
                "compressed_swizzled_bitmask", None
            ),
            fuse_transpose_cusparselt=fuse_transpose_cusparselt,
            alg_id_cusparselt=alg_id_cusparselt,
            requires_grad=requires_grad,
        )

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
        if func._overloadpacket not in cls.SPARSE_DISPATCH:
            raise NotImplementedError(
                f"{cls.__name__} only supports a specific set of operations, "
                f"can't perform requested op ({func.__name__})"
            )
        return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)

    @classmethod
    def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
        """
        Loads the op overload sparse dispatch table for the current class.
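
        A hypothetical debugging override (the ``log_mm`` handler below is illustrative only,
        not part of the API):

            >>> # xdoctest: +SKIP
            >>> from torch.sparse._semi_structured_ops import semi_sparse_mm
            >>> def log_mm(func, types, args, kwargs):
            ...     print(f"dispatching {func}")
            ...     return semi_sparse_mm(func, types, args, kwargs)
            >>> SparseSemiStructuredTensor._load_dispatch_table(
            ...     {torch.ops.aten.mm: log_mm}
            ... )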
213 """ 214 if getattr(cls, "SPARSE_DISPATCH", None) is None: 215 cls.SPARSE_DISPATCH = { 216 torch.ops.aten.values: semi_sparse_values, 217 torch.ops.aten.indices: semi_sparse_indices, 218 torch.ops.aten.is_same_size: fallback_dispatcher, 219 torch.ops.aten.detach_: fallback_dispatcher, 220 torch.ops.aten.detach: semi_sparse_detach, 221 torch.ops.aten.t: semi_sparse_t, 222 torch.ops.aten.view: semi_sparse_view, 223 torch.ops.aten.mm: semi_sparse_mm, 224 torch.ops.aten.matmul: semi_sparse_mm, 225 torch.ops.aten.addmm: semi_sparse_addmm, 226 torch.ops.aten.linear: semi_sparse_linear, 227 torch.ops.aten._to_copy: fallback_dispatcher, 228 } 229 if custom_dispatch_table is not None: 230 cls.SPARSE_DISPATCH.update(custom_dispatch_table) 231 232 @classmethod 233 def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None: 234 """ 235 Assert that the given tensor is valid for semi-structured sparse compression. 236 """ 237 # check device 238 if not original_tensor.is_cuda: 239 raise RuntimeError( 240 f"Error original_tensor.device= {original_tensor.device} is not supported! " 241 "Only CUDA tensors are currently supported." 242 ) 243 244 # check dim 245 if original_tensor.dim() != 2: 246 raise RuntimeError( 247 f"Error original_tensor.dim = {original_tensor.dim()} is not supported! " 248 "Only 2d tensors are currently supported." 249 ) 250 251 # check contiguous 252 if not original_tensor.is_contiguous(): 253 raise RuntimeError( 254 "Error original_tensor is not contiguous!" 255 "Only contiguous tensors are currently supported." 256 ) 257 258 # check dtype 259 if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS: 260 raise RuntimeError( 261 f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! " 262 "dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}" 263 ) 264 265 # check shape 266 m, n = original_tensor.shape 267 min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows 268 min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols 269 if m < min_rows or m % min_rows or n < min_cols or n % min_cols: 270 # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples 271 raise RuntimeError( 272 f"Error original_tensor.shape {original_tensor.shape} is not supported! " 273 f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})" 274 ) 275 276 @classmethod 277 def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: 278 """ 279 Calculates padding for dense tensor and pads tensor if necessary. 280 If padding is not required, this function returns the original tensor. 
281 """ 282 # only 2d matmul 283 assert dense_input.dim() == 2 284 285 # check shape 286 m, n = dense_input.shape 287 min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows 288 min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols 289 290 # calculate padding 291 to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 292 to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 293 if to_pad_m or to_pad_n: 294 return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) 295 else: 296 return dense_input 297 298 def to_dense(self): 299 col = self.shape[-1] 300 return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) 301 302 @classmethod 303 def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor": 304 raise NotImplementedError 305 306 def _mm( 307 self, 308 B: torch.Tensor, 309 *, 310 bias: Optional[torch.Tensor] = None, 311 **kwargs, 312 ) -> torch.Tensor: 313 raise NotImplementedError 314 315 316def to_sparse_semi_structured( 317 original_tensor: torch.Tensor, 318 transposed: bool = False, 319) -> SparseSemiStructuredTensor: 320 """ 321 This function converts a dense tensor into a sparse semi-structured tensor. 322 It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor. 323 324 This function will check to ensure the dense tensor has the right dtype, size, dims, and device. 325 We currently only support semi-structured sparse tensors for 2d CUDA tensors. 326 Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in 327 `_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8). 328 329 Args: 330 original_tensor (Tensor): the dense tensor to convert 331 transposed (bool, optional): deprecated arg to be removed in another release. Do not use. 332 Returns: 333 SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor 334 Raises: 335 None 336 Example: 337 >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) 338 >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() 339 tensor([[0., 0., 1., ..., 0., 1., 1.], 340 [0., 0., 1., ..., 0., 1., 1.], 341 [0., 0., 1., ..., 0., 1., 1.], 342 ..., 343 [0., 0., 1., ..., 0., 1., 1.], 344 [0., 0., 1., ..., 0., 1., 1.], 345 [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16) 346 >>> A_sparse = to_sparse_semi_structured(A) 347 SparseSemiStructuredTensor(shape=torch.Size([128, 128])) 348 >>> A_sparse.values() 349 tensor([[1., 1., 1., ..., 1., 1., 1.], 350 [1., 1., 1., ..., 1., 1., 1.], 351 [1., 1., 1., ..., 1., 1., 1.], 352 ..., 353 [1., 1., 1., ..., 1., 1., 1.], 354 [1., 1., 1., ..., 1., 1., 1.], 355 [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), 356 >>> A_sparse.indices() 357 tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370], 358 [-4370, -4370, -4370, ..., -4370, -4370, -4370], 359 [-4370, -4370, -4370, ..., -4370, -4370, -4370], 360 ..., 361 [-4370, -4370, -4370, ..., -4370, -4370, -4370], 362 [-4370, -4370, -4370, ..., -4370, -4370, -4370], 363 [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)) 364 """ 365 if transposed: 366 warnings.warn( 367 "Setting transpose from `to_sparse_semi_structured` is deprecated " 368 "and will be removed in a future release. 
" 369 "`SparseSemiStructuredTensor` only support contiguous input tensors.", 370 FutureWarning, 371 stacklevel=2, 372 ) 373 374 # set from _FORCE_CUTLASS flag 375 SPARSE_SUBCLASS = ( 376 torch.sparse.SparseSemiStructuredTensorCUTLASS 377 if SparseSemiStructuredTensor._FORCE_CUTLASS 378 else torch.sparse.SparseSemiStructuredTensorCUSPARSELT 379 ) 380 381 return SPARSE_SUBCLASS.from_dense(original_tensor) 382 383 384class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): 385 """ 386 This class implements semi-structured sparsity for the CUTLASS backend. 387 388 389 In this implementation, the specified elements and metadata are stored seprately, 390 in packed and meta respectively. 391 392 When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and 393 sparse_semi_structured_from_dense for conversion to the compressed format. 394 """ 395 396 BACKEND = "cutlass" 397 _DTYPE_SHAPE_CONSTRAINTS = { 398 torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16), 399 torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), 400 torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), 401 torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4), 402 } 403 404 @classmethod 405 def from_dense( 406 cls, original_tensor: torch.Tensor 407 ) -> "SparseSemiStructuredTensorCUTLASS": 408 cls._validate_device_dim_dtype_shape(original_tensor) 409 ( 410 sparse_tensor_cutlass, 411 meta_tensor_cutlass, 412 ) = sparse_semi_structured_from_dense_cutlass(original_tensor) 413 return cls( 414 original_tensor.shape, 415 packed=sparse_tensor_cutlass, 416 meta=meta_tensor_cutlass, 417 packed_t=None, 418 meta_t=None, 419 compressed_swizzled_bitmask=None, 420 requires_grad=original_tensor.requires_grad, 421 ) 422 423 def to_dense(self): 424 assert self.meta is not None and self.packed is not None 425 return ( 426 sparse_semi_structured_to_dense_cutlass( 427 self.packed, 428 self.meta, 429 ) 430 if self.meta.ndim == 2 431 else super().to_dense() 432 ) 433 434 @classmethod 435 def prune_dense_static_sort( 436 cls, original_tensor: torch.Tensor, algorithm="" 437 ) -> "SparseSemiStructuredTensor": 438 """ 439 This function takes in a unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile. 440 441 It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns. 442 The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`. 443 444 Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor. 445 It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed 446 pruned dense tensor. 447 Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively. 448 449 Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern 450 This can be used in the backward pass to mask the gradients. 

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4]  -> prune 4x4 tile -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
        [1 2 6 2]                       [0 0 6 2]                                     -> metadata

                                                  -> pack to transposed CUTLASS       -> packed_t
                                                     semi-structured representation   -> metadata_t

                                                  -> compute swizzled bitmask         -> compressed_swizzled_bitmask


        The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUTLASS
        from torch.sparse._semi_structured_conversions import (
            _sparse_semi_structured_tile,
            _compute_compressed_swizzled_bitmask,
            sparse_semi_structured_from_dense_cutlass,
        )

        pruned = _sparse_semi_structured_tile(dense)
        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
        ```
        """
        # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
        (
            packed,
            meta,
            packed_t,
            meta_t,
            compressed_swizzled_bitmask,
        ) = torch._sparse_semi_structured_tile(
            original_tensor, algorithm=algorithm, use_cutlass=True
        )

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        )

    def _mm(
        self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        cls_name = self.__class__.__name__
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{cls_name}` matmul: Broadcasting is not implemented"
            )
        if self.packed is None or self.meta is None:
            raise NotImplementedError(
                f"`{cls_name}` matmul: operation is not supported"
            )
        else:
            if bias is None:
                res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
            else:
                res = torch._sparse_semi_structured_addmm(
                    bias, self.packed, self.meta, B
                )
            return res[: self.shape[0]]


class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
    """
    The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
    packed = [ specified elements of original tensor | metadata ]
    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
    attributes respectively.

    cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
    as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
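
    A size sketch grounded in the layout described above (requires cuSPARSELt support):

        >>> # xdoctest: +SKIP
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()  # (128, 128), 2:4 sparse
        >>> A_sparse = SparseSemiStructuredTensorCUSPARSELT.from_dense(A)
        >>> A_sparse.packed.numel() >= 128 * 128 // 2  # kept elements first, then metadata
        True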
533 """ 534 535 BACKEND = "cusparselt" 536 _DTYPE_SHAPE_CONSTRAINTS = { 537 torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16), 538 torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), 539 torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), 540 } 541 542 @classmethod 543 def from_dense( 544 cls, original_tensor: torch.Tensor 545 ) -> "SparseSemiStructuredTensorCUSPARSELT": 546 cls._validate_device_dim_dtype_shape(original_tensor) 547 return cls( 548 shape=original_tensor.shape, 549 packed=torch._cslt_compress(original_tensor), 550 meta=None, 551 packed_t=None, 552 meta_t=None, 553 compressed_swizzled_bitmask=None, 554 fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE, 555 alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID, 556 requires_grad=original_tensor.requires_grad, 557 ) 558 559 @classmethod 560 def prune_dense_static_sort( 561 cls, original_tensor: torch.Tensor, algorithm="" 562 ) -> "SparseSemiStructuredTensor": 563 """ 564 This function does the same thing as described in SparseSemiStructuredCUTLASS, but uses the cuSPASRELt metadata 565 layout and sparse matmul. 566 567 The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor. 568 569 [9 1 7 4] [9 0 7 0] 570 [1 2 3 0] [0 2 0 0] 571 [8 3 5 4] -> prune 4x4 tile -> [8 0 0 4] -> pack to cuSPARSELT semi-structured -> packed 572 [1 2 6 2] [0 0 6 2] 573 574 -> pack to transposed cuSPARSELt -> packed_t 575 semi-structured representation 576 577 -> compute swizzled bitmask -> compressed_swizzled_bitmask 578 579 580 The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below: 581 ``` 582 from torch.sparse import SparseSemiStructuredTensorCUSPARSELT 583 from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask 584 585 pruned = _sparse_semi_structured_tile(dense) 586 packed_cusparselt = torch._cslt_compress(pruned) 587 packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) 588 bitmask = _compute_compressed_swizzled_bitmask(pruned) 589 590 SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cutlass, None, packed_t_cutlass, None, bitmask) 591 ``` 592 """ 593 ( 594 packed, 595 meta, 596 packed_t, 597 meta_t, 598 compressed_swizzled_bitmask, 599 ) = torch._sparse_semi_structured_tile( 600 original_tensor, algorithm=algorithm, use_cutlass=False 601 ) 602 603 return cls( 604 original_tensor.shape, 605 packed=packed, 606 meta=meta, 607 packed_t=packed_t, 608 meta_t=meta_t, 609 compressed_swizzled_bitmask=compressed_swizzled_bitmask, 610 requires_grad=False, 611 ) 612 613 def _mm( 614 self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs 615 ) -> torch.Tensor: 616 if isinstance(B, SparseSemiStructuredTensor): 617 raise ValueError( 618 "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware" 619 ) 620 if self.ndim != 2 or B.ndim != 2: 621 raise NotImplementedError( 622 f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented" 623 ) 624 if B.dtype != self.dtype: 625 raise NotImplementedError( 626 f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, " 627 f"with A.dtype={self.dtype} and B.dtype={B.dtype}. " 628 "This operation is only supported when A and B have the same data type." 
            )
        if bias is not None and bias.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
                f"with A.dtype=B.dtype={self.dtype} and C.dtype={bias.dtype}. "
                "This operation is only supported when A, B and C have the same data type."
            )
        if self.packed is None:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: operation is not supported"
            )
        else:
            res = torch._cslt_sparse_mm(
                self.packed,
                B,
                bias=bias,
                transpose_result=self.fuse_transpose_cusparselt,
                alg_id=self.alg_id_cusparselt,
            )
            return res.t() if self.fuse_transpose_cusparselt else res