xref: /aosp_15_r20/external/pytorch/torch/sparse/semi_structured.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import warnings
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch.sparse._semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
    fallback_dispatcher,
    semi_sparse_addmm,
    semi_sparse_detach,
    semi_sparse_indices,
    semi_sparse_linear,
    semi_sparse_mm,
    semi_sparse_t,
    semi_sparse_values,
    semi_sparse_view,
)


__all__ = [
    "SparseSemiStructuredTensor",
    "SparseSemiStructuredTensorCUTLASS",
    "SparseSemiStructuredTensorCUSPARSELT",
    "to_sparse_semi_structured",
]

_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
    "_SEMI_STRUCTURED_SPARSE_CONFIG",
    "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)
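# Per-dtype shape constraints: the sparse operand must be at least as large as (and a
# multiple of) (sparse_min_rows, sparse_min_cols), while dense operands are padded up
# to multiples of (dense_min_rows, dense_min_cols) when needed.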


class SparseSemiStructuredTensor(torch.Tensor):
    """
    This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are zero,
    with n depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
    structured sparsity: in the 2:4 case, at most 2 of every 4 contiguous elements are nonzero.

    There are two backends available for semi-structured sparsity, either cuSPARSELt or CUTLASS.
    This class is meant to serve as a base class for both implementations. SparseSemiStructuredTensorCUTLASS
    and SparseSemiStructuredTensorCUSPARSELT both inherit from this class and define three backend-specific items.
    Note that as such, this class cannot be instantiated directly.

    - `_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend-specific dense/sparse min shape constraints
    - `def from_dense()` - backend-specific compression routines
    - `def _mm()` - backend-specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
    """

    _DEFAULT_ALG_ID: int = 0
    _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
    _FORCE_CUTLASS: bool = True
    _FUSE_TRANSPOSE: bool = False
    _PROTOTYPE_WARNING_SHOWN: bool = False

    BACKEND: str
    SPARSE_DISPATCH: Dict[Callable, Callable]

    packed: Optional[torch.Tensor]
    meta: Optional[torch.Tensor]
    packed_t: Optional[torch.Tensor]
    meta_t: Optional[torch.Tensor]
    compressed_swizzled_bitmask: Optional[torch.Tensor]
    fuse_transpose_cusparselt: bool
    alg_id_cusparselt: int

    __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"]

    @staticmethod
    def __new__(  # noqa: PYI034
        cls,
        shape: torch.Size,
        packed: Optional[torch.Tensor],
        meta: Optional[torch.Tensor],
        packed_t: Optional[torch.Tensor],
        meta_t: Optional[torch.Tensor],
        compressed_swizzled_bitmask: Optional[torch.Tensor],
        fuse_transpose_cusparselt: bool = False,
        alg_id_cusparselt: int = 0,
        requires_grad: bool = False,
    ):
        """
        Create a new instance of the tensor subclass from the compressed sparse representation.

        We have the option to create the subclass with the compressed representations of both X and X', for training.
        For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.

        Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)

        Args:
            shape: The shape of the original dense tensor
            packed: The compressed representation of the original dense tensor
            meta: The metadata of the original dense tensor, if it is stored separately
            packed_t: The compressed representation of the transposed original dense tensor
            meta_t: The metadata of the transposed original dense tensor, if it is stored separately
            compressed_swizzled_bitmask: The masks used by the CUTLASS backend to determine which threads should
                                         participate in the computation. Used for pointwise ops.
            fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
                                       with a matmul, which is useful in the case of 2:4 sparse training.
            alg_id_cusparselt: The algorithm id to use when using cuSPARSELt, which affects matmul performance.

        Returns:
            torch.Tensor: A torch.Tensor wrapper subclass.

        Raises:
            ValueError: If both packed and packed_t are None.
        """
        if not cls._PROTOTYPE_WARNING_SHOWN:
            warnings.warn(
                (
                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
                    "and will change in the near future. Please open a GitHub issue "
                    "for feature requests and see our documentation on the torch.sparse "
                    "module for further information about the project."
                ),
                UserWarning,
            )
            cls._PROTOTYPE_WARNING_SHOWN = True

            # Because this only runs once, we also load the dispatch table here.
            # We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead.
            # This is also useful since it allows users to overload the dispatch table for debugging / testing.
            cls._load_dispatch_table()

            # we can also register the classes with dynamo when the warning is shown.
            torch._dynamo.allow_in_graph(cls)

        if packed is not None:
            previous_tensor = packed
        elif packed_t is not None:
            previous_tensor = packed_t
        else:
            raise ValueError("At least one of packed or packed_t must be provided")

        kwargs = {
            "device": previous_tensor.device,
            "dtype": previous_tensor.dtype,
            "layout": previous_tensor.layout,
            "requires_grad": requires_grad,
        }
        tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

        tensor.packed = packed
        tensor.meta = meta
        tensor.packed_t = packed_t
        tensor.meta_t = meta_t
        tensor.compressed_swizzled_bitmask = compressed_swizzled_bitmask
        tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
        tensor.alg_id_cusparselt = alg_id_cusparselt
        return tensor

    def __repr__(self) -> str:  # type: ignore[override]
        assert hasattr(self, "shape")
        return f"{self.__class__.__name__}(shape={self.shape})"

    def __tensor_flatten__(
        self,
    ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
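        # Flatten only the inner tensors that are actually present; the remaining state
        # (shape, cuSPARSELt flags, requires_grad) travels as plain metadata.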
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
        tensor_meta = (
            self.shape,
            self.fuse_transpose_cusparselt,
            self.alg_id_cusparselt,
            self.requires_grad,
        )
        return inner_tensors, tensor_meta

    @classmethod
    def __tensor_unflatten__(
        cls,
        inner_tensors,
        tensor_meta: Tuple[torch.Size, bool, int, bool],
        outer_size,
        outer_stride,
    ) -> torch.Tensor:
        shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
        return cls(
            shape=shape,
            packed=inner_tensors.get("packed", None),
            meta=inner_tensors.get("meta", None),
            packed_t=inner_tensors.get("packed_t", None),
            meta_t=inner_tensors.get("meta_t", None),
            compressed_swizzled_bitmask=inner_tensors.get(
                "compressed_swizzled_bitmask", None
            ),
            fuse_transpose_cusparselt=fuse_transpose_cusparselt,
            alg_id_cusparselt=alg_id_cusparselt,
            requires_grad=requires_grad,
        )

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
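        # Route each supported aten overload packet to its sparse implementation;
        # anything outside SPARSE_DISPATCH is rejected.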
        if func._overloadpacket not in cls.SPARSE_DISPATCH:
            raise NotImplementedError(
                f"{cls.__name__} only supports a specific set of operations, "
                f"can't perform requested op ({func.__name__})"
            )
        return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)

    @classmethod
    def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
        """
        Loads the op overload sparse dispatch table for the current class.
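
        The table is only built once; a custom table passed on the first call is merged
        into the defaults, which can be handy for debugging / testing. A hypothetical
        override sketch (`my_mm_handler` is a placeholder, not a real function):

            SparseSemiStructuredTensor._load_dispatch_table(
                {torch.ops.aten.mm: my_mm_handler}
            )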
        """
        if getattr(cls, "SPARSE_DISPATCH", None) is None:
            cls.SPARSE_DISPATCH = {
                torch.ops.aten.values: semi_sparse_values,
                torch.ops.aten.indices: semi_sparse_indices,
                torch.ops.aten.is_same_size: fallback_dispatcher,
                torch.ops.aten.detach_: fallback_dispatcher,
                torch.ops.aten.detach: semi_sparse_detach,
                torch.ops.aten.t: semi_sparse_t,
                torch.ops.aten.view: semi_sparse_view,
                torch.ops.aten.mm: semi_sparse_mm,
                torch.ops.aten.matmul: semi_sparse_mm,
                torch.ops.aten.addmm: semi_sparse_addmm,
                torch.ops.aten.linear: semi_sparse_linear,
                torch.ops.aten._to_copy: fallback_dispatcher,
            }
            if custom_dispatch_table is not None:
                cls.SPARSE_DISPATCH.update(custom_dispatch_table)

    @classmethod
    def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None:
        """
        Assert that the given tensor is valid for semi-structured sparse compression.
        """
        # check device
        if not original_tensor.is_cuda:
            raise RuntimeError(
                f"Error original_tensor.device= {original_tensor.device} is not supported! "
                "Only CUDA tensors are currently supported."
            )

        # check dim
        if original_tensor.dim() != 2:
            raise RuntimeError(
                f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
                "Only 2d tensors are currently supported."
            )

        # check contiguous
        if not original_tensor.is_contiguous():
            raise RuntimeError(
                "Error original_tensor is not contiguous! "
                "Only contiguous tensors are currently supported."
            )

        # check dtype
        if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
            raise RuntimeError(
                f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
                f"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
            )

        # check shape
        m, n = original_tensor.shape
        min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
        min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
        if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
            # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
            raise RuntimeError(
                f"Error original_tensor.shape {original_tensor.shape} is not supported! "
                f"Both dimensions must be at least, and a multiple of, ({min_rows}, {min_cols})"
            )

    @classmethod
    def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
        """
        Calculates padding for dense tensor and pads tensor if necessary.
        If padding is not required, this function returns the original tensor.
        """
        # only 2d matmul
        assert dense_input.dim() == 2

        # check shape
        m, n = dense_input.shape
        min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
        min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols

        # calculate padding
        to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
        to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0
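        # e.g. (illustrative values) m = 30 with min_rows = 32 pads by -30 % 32 = 2 rows,
        # while m = 64 is already a multiple and needs no padding.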
        if to_pad_m or to_pad_n:
            return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
        else:
            return dense_input

    def to_dense(self):
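        # Fallback densification: multiplying by a dense identity matrix reproduces the
        # dense tensor using only the supported mm op.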
        col = self.shape[-1]
        return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))

    @classmethod
    def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor":
        raise NotImplementedError

    def _mm(
        self,
        B: torch.Tensor,
        *,
        bias: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        raise NotImplementedError


def to_sparse_semi_structured(
    original_tensor: torch.Tensor,
    transposed: bool = False,
) -> SparseSemiStructuredTensor:
    """
    This function converts a dense tensor into a sparse semi-structured tensor.
    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.

    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
    Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
    `_DTYPE_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).

    Args:
        original_tensor (Tensor): the dense tensor to convert
        transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
    Returns:
        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
    Raises:
        None
    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
        tensor([[0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                ...,
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.],
                [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse = to_sparse_semi_structured(A)
        SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
        >>> A_sparse.values()
        tensor([[1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                ...,
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.],
                [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16)
        >>> A_sparse.indices()
        tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                ...,
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
                [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)
    """
    if transposed:
        warnings.warn(
            "Setting transpose from `to_sparse_semi_structured` is deprecated "
            "and will be removed in a future release. "
            "`SparseSemiStructuredTensor` only supports contiguous input tensors.",
            FutureWarning,
            stacklevel=2,
        )

    # set from _FORCE_CUTLASS flag
    SPARSE_SUBCLASS = (
        torch.sparse.SparseSemiStructuredTensorCUTLASS
        if SparseSemiStructuredTensor._FORCE_CUTLASS
        else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
    )

    return SPARSE_SUBCLASS.from_dense(original_tensor)


class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
    """
    This class implements semi-structured sparsity for the CUTLASS backend.


    In this implementation, the specified elements and metadata are stored separately,
    in packed and meta respectively.

    When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
    sparse_semi_structured_from_dense_cutlass for conversion to the compressed format.
    """

    BACKEND = "cutlass"
    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
        torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
    }

    @classmethod
    def from_dense(
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUTLASS":
        cls._validate_device_dim_dtype_shape(original_tensor)
        (
            sparse_tensor_cutlass,
            meta_tensor_cutlass,
        ) = sparse_semi_structured_from_dense_cutlass(original_tensor)
        return cls(
            original_tensor.shape,
            packed=sparse_tensor_cutlass,
            meta=meta_tensor_cutlass,
            packed_t=None,
            meta_t=None,
            compressed_swizzled_bitmask=None,
            requires_grad=original_tensor.requires_grad,
        )

    def to_dense(self):
        assert self.meta is not None and self.packed is not None
        return (
            sparse_semi_structured_to_dense_cutlass(
                self.packed,
                self.meta,
            )
            if self.meta.ndim == 2
            else super().to_dense()
        )

    @classmethod
    def prune_dense_static_sort(
        cls, original_tensor: torch.Tensor, algorithm=""
    ) -> "SparseSemiStructuredTensor":
        """
        This function takes in an unpruned dense tensor and runs a (branchless) static sort across a 4x4 tile.

        It greedily picks the largest values in the tile, upholding the 2:4 sparsity constraint across both rows and columns.
        The algorithm used to prune the matrix is implemented in `_sparse_semi_structured_tile`.

        Then it creates the packed and meta tensors for the compressed sparse representation of the pruned dense tensor.
        It also calculates the packed_t and meta_t tensors for the compressed sparse representation of the transposed
        pruned dense tensor.
        Since we cannot transpose the compressed representations, we store both for the fw/bw pass respectively.

        Finally, this function also computes a compressed swizzled bitmask that encodes the sparsity pattern.
        This can be used in the backward pass to mask the gradients.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4] -> prune 4x4 tile  -> [8 0 0 4] -> pack to CUTLASS semi-structured -> packed
        [1 2 6 2]                       [0 0 6 2]                                    -> metadata

                                                  -> pack to transposed CUTLASS      -> packed_t
                                                     semi-structured representation  -> metadata_t

                                                  -> compute swizzled bitmask        -> compressed_swizzled_bitmask


        The equivalent PyTorch code to create the same five outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUTLASS
        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask

        pruned = _sparse_semi_structured_tile(dense)
        packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
        packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUTLASS(dense.shape, packed_cutlass, meta_cutlass, packed_t_cutlass, meta_t_cutlass, bitmask)
        ```
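
        A typical call (sketch): `sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(dense)`,
        where `dense` is a dense CUDA tensor satisfying the dtype/shape constraints above.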
        """
        # We can either pack to the CUTLASS or cuSPARSELt representation, depending on the use_cutlass flag.
        (
            packed,
            meta,
            packed_t,
            meta_t,
            compressed_swizzled_bitmask,
        ) = torch._sparse_semi_structured_tile(
            original_tensor, algorithm=algorithm, use_cutlass=True
        )

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        )

    def _mm(
        self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        cls_name = self.__class__.__name__
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{cls_name}` matmul: Broadcasting is not implemented"
            )
        if self.packed is None or self.meta is None:
            raise NotImplementedError(
                f"`{cls_name}` matmul: operation is not supported"
            )
        else:
            if bias is None:
                res = torch._sparse_semi_structured_mm(self.packed, self.meta, B)
            else:
                res = torch._sparse_semi_structured_addmm(
                    bias, self.packed, self.meta, B
                )
            return res[: self.shape[0]]


class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
    """
    The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
    packed = [ specified elements of original tensor | metadata ]
    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
    attributes respectively.
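    For example, a 128 x 64 float16 tensor keeps 128 * 64 // 2 = 4096 specified elements, followed by its metadata.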

    cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
    as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
    """

    BACKEND = "cusparselt"
    _DTYPE_SHAPE_CONSTRAINTS = {
        torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
        torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
        torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
    }

    @classmethod
    def from_dense(
        cls, original_tensor: torch.Tensor
    ) -> "SparseSemiStructuredTensorCUSPARSELT":
        cls._validate_device_dim_dtype_shape(original_tensor)
        return cls(
            shape=original_tensor.shape,
            packed=torch._cslt_compress(original_tensor),
            meta=None,
            packed_t=None,
            meta_t=None,
            compressed_swizzled_bitmask=None,
            fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
            alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
            requires_grad=original_tensor.requires_grad,
        )

    @classmethod
    def prune_dense_static_sort(
        cls, original_tensor: torch.Tensor, algorithm=""
    ) -> "SparseSemiStructuredTensor":
        """
        This function does the same thing as described in SparseSemiStructuredTensorCUTLASS, but uses the cuSPARSELt metadata
        layout and sparse matmul.

        The only functional difference is that cuSPARSELt stores `metadata` and `packed` together into a single tensor.

        [9 1 7 4]                       [9 0 7 0]
        [1 2 3 0]                       [0 2 0 0]
        [8 3 5 4] -> prune 4x4 tile  -> [8 0 0 4] -> pack to cuSPARSELt semi-structured -> packed
        [1 2 6 2]                       [0 0 6 2]

                                                  -> pack to transposed cuSPARSELt      -> packed_t
                                                     semi-structured representation

                                                  -> compute swizzled bitmask           -> compressed_swizzled_bitmask


        The equivalent PyTorch code to create the same three outputs from the dense tensor can be found below:
        ```
        from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
        from torch.sparse._semi_structured_conversions import _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask

        pruned = _sparse_semi_structured_tile(dense)
        packed_cusparselt = torch._cslt_compress(pruned)
        packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
        bitmask = _compute_compressed_swizzled_bitmask(pruned)

        SparseSemiStructuredTensorCUSPARSELT(dense.shape, packed_cusparselt, None, packed_t_cusparselt, None, bitmask)
        ```
        """
        (
            packed,
            meta,
            packed_t,
            meta_t,
            compressed_swizzled_bitmask,
        ) = torch._sparse_semi_structured_tile(
            original_tensor, algorithm=algorithm, use_cutlass=False
        )

        return cls(
            original_tensor.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=compressed_swizzled_bitmask,
            requires_grad=False,
        )

    def _mm(
        self, B: torch.Tensor, *, bias: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        if isinstance(B, SparseSemiStructuredTensor):
            raise ValueError(
                "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
            )
        if self.ndim != 2 or B.ndim != 2:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
            )
        if B.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
                f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
                "This operation is only supported when A and B have the same data type."
            )
        if bias is not None and bias.dtype != self.dtype:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
                f"with A.dtype=B.dtype={self.dtype} and C.dtype={bias.dtype}. "
                "This operation is only supported when A, B and C have the same data type."
            )
        if self.packed is None:
            raise NotImplementedError(
                f"`{self.__class__.__name__}` matmul: operation is not supported"
            )
        else:
            res = torch._cslt_sparse_mm(
                self.packed,
                B,
                bias=bias,
                transpose_result=self.fuse_transpose_cusparselt,
                alg_id=self.alg_id_cusparselt,
            )
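            # When the transpose is fused into the kernel (transpose_result=True), the
            # result comes back transposed, so transpose it again before returning.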
            return res.t() if self.fuse_transpose_cusparselt else res
649