# mypy: allow-untyped-defs
import contextlib

import torch


__all__ = [
    "fallback_dispatcher",
    "semi_sparse_values",
    "semi_sparse_indices",
    "semi_sparse_t",
    "semi_sparse_view",
    "semi_sparse_detach",
    "semi_sparse_mm",
    "semi_sparse_addmm",
    "semi_sparse_linear",
]


@contextlib.contextmanager
def no_dispatch():
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def fallback_dispatcher(func, types, args, kwargs):
    with no_dispatch():
        return func(*args)


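# Dispatch-table sketch (illustrative; the actual registration lives in
# torch/sparse/semi_structured.py and may differ in detail): these handlers are
# invoked from SparseSemiStructuredTensor.__torch_dispatch__ via a mapping
# roughly like
#
#   SPARSE_DISPATCH = {
#       torch.ops.aten.values.default: semi_sparse_values,
#       torch.ops.aten.indices.default: semi_sparse_indices,
#       torch.ops.aten.t.default: semi_sparse_t,
#       torch.ops.aten.mm.default: semi_sparse_mm,
#       torch.ops.aten.addmm.default: semi_sparse_addmm,
#       torch.ops.aten.linear.default: semi_sparse_linear,
#   }

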
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        m, k = A.shape
        num_kept_elements = m * k // 2
        return A.packed[:num_kept_elements].view(m, -1)
    else:
        return A.packed.detach()


def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        m, k = A.shape
        num_kept_elements = m * k // 2
        metadata = A.packed[num_kept_elements:].view(m, -1)
        return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
    else:
        return A.meta


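# Layout note (a sketch of how the two branches above differ, assuming the
# current backends): when `meta is None` the kept values and the 2:4 metadata
# share a single `packed` buffer, so the first m * k // 2 entries are the values
# and the remainder is metadata (this corresponds to the cuSPARSELt backend);
# when `meta` is set, values (`packed`) and metadata (`meta`) already live in
# separate tensors (the CUTLASS backend). For example, a 128x128 fp16 matrix
# keeps 128 * 128 // 2 = 8192 values, viewed as a (128, 64) tensor.

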
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
    assert len(self.shape) == 2
    # Because we currently cannot go from the compressed representation back to
    # the dense representation, we just keep track of how many times we have
    # been transposed. Depending on whether the sparse matrix is the first or
    # second argument to a matmul, we expect an even or odd number of calls to
    # transpose, respectively.
    return self.__class__(
        torch.Size([self.shape[-1], self.shape[0]]),
        packed=self.packed_t,
        meta=self.meta_t,
        packed_t=self.packed,
        meta_t=self.meta,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
        if self.compressed_swizzled_bitmask is not None
        else None,
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
    )


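# Because packed/packed_t and meta/meta_t are simply swapped on every transpose,
# an even number of transposes should hand back the original buffers, e.g.
# (illustrative):
#
#   A_t = A.t()                        # packed <-> packed_t swapped
#   assert A_t.t().packed is A.packed  # swapped back after the second transpose

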
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    self, shape = args
    if tuple(shape) != self.shape:
        raise NotImplementedError(
            f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
        )
    return self


def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    return self.__class__(
        shape=self.shape,
        packed=self.packed,
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        requires_grad=False,
    )


def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        # sparse @ dense: pad the dense operand to the kernel's alignment,
        # run the sparse mm, then slice the padding back off.
        row, col = B.shape
        B_padded = A._pad_dense_input(B)
        res = A._mm(B_padded)
        return res[:, :col]
    else:
        # dense @ sparse: compute A @ B as (B.t() @ A.t()).t() so the sparse
        # operand sits on the left of the underlying kernel.
        B_t = B.t()
        assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
        row, col = A.shape
        A_padded = B._pad_dense_input(A)
        res = B_t._mm(A_padded.t()).t()
        return res[:row, :]


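# Usage sketch (illustrative only; assumes a GPU and a dtype supported by the
# 2:4 sparse kernels, e.g. fp16):
#
#   W = torch.sparse.to_sparse_semi_structured(dense_weight)  # 2:4-pruned weight
#   y = torch.mm(W, x)   # sparse @ dense, first branch above
#   y = torch.mm(x, W)   # dense @ sparse, second branch above

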
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 3
    bias, A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )
    # As in the dense @ sparse case above, A @ B is computed as (B.t() @ A.t()).t(),
    # with the bias passed through to the sparse kernel.
    B_t = B.t()
    assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
    row, col = A.shape
    A_padded = B_t._pad_dense_input(A)
    result = B_t._mm(A_padded.t(), bias=bias).t()
    return result[:row, :]


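# Usage sketch (illustrative): only the second matmul operand may be sparse, so
# a call like
#
#   out = torch.addmm(bias, x, W_sparse)
#
# lands here, while torch.addmm(bias, W_sparse, x) raises NotImplementedError.

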
def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) in [2, 3]
    A, B = args[:2]
    bias = args[2] if len(args) == 3 else None

    # Flatten any leading batch dimensions so the sparse mm sees a 2-D input,
    # then restore them on the way out.
    shape = A.shape
    A_2d = A.view(-1, shape[-1])

    if bias is None:
        res = A_2d @ B.t()
    else:
        res = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, A_2d, B.t()],
        )

    return res.view(*shape[:-1], -1)

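# Usage sketch (illustrative, assuming a supported GPU, fp16 weights, and a
# weight that has already been pruned to the 2:4 pattern):
#
#   linear = torch.nn.Linear(1024, 1024, bias=True).half().cuda()
#   linear.weight = torch.nn.Parameter(
#       torch.sparse.to_sparse_semi_structured(linear.weight)
#   )
#   y = linear(x)  # F.linear(x, W_sparse, bias) dispatches to semi_sparse_linear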