# mypy: allow-untyped-defs
import contextlib

import torch


__all__ = [
    "fallback_dispatcher",
    "semi_sparse_values",
    "semi_sparse_indices",
    "semi_sparse_t",
    "semi_sparse_view",
    "semi_sparse_detach",
    "semi_sparse_mm",
    "semi_sparse_addmm",
    "semi_sparse_linear",
]


@contextlib.contextmanager
def no_dispatch():
    # Temporarily disable __torch_dispatch__ so ops inside the block run on the
    # underlying tensors instead of re-entering the subclass dispatch handlers.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def fallback_dispatcher(func, types, args, kwargs):
    with no_dispatch():
        return func(*args)


def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        # Values and metadata share a single packed buffer: the first
        # m * k // 2 elements are the kept (specified) values.
        m, k = A.shape
        num_kept_elements = m * k // 2
        return A.packed[:num_kept_elements].view(m, -1)
    else:
        # Metadata is stored separately, so `packed` holds only the values.
        return A.packed.detach()


def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        # Metadata is stored after the kept values in the packed buffer.
        m, k = A.shape
        num_kept_elements = m * k // 2
        metadata = A.packed[num_kept_elements:].view(m, -1)
        return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
    else:
        return A.meta


def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
    assert len(self.shape) == 2
    # Because we cannot go from the compressed representation back to the dense
    # representation currently, we just keep track of how many times we have been
    # transposed. Depending on whether the sparse matrix is the first or second
    # argument, we expect an even / odd number of calls to transpose respectively.
    return self.__class__(
        torch.Size([self.shape[-1], self.shape[0]]),
        packed=self.packed_t,
        meta=self.meta_t,
        packed_t=self.packed,
        meta_t=self.meta,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
        if self.compressed_swizzled_bitmask is not None
        else None,
        fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
        alg_id_cusparselt=self.alg_id_cusparselt,
    )


def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    self, shape = args
    if tuple(shape) != self.shape:
        raise NotImplementedError(
            f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
        )
    return self


def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    # Rebuild the same compressed tensor without autograd tracking.
    return self.__class__(
        shape=self.shape,
        packed=self.packed,
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        requires_grad=False,
    )


def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        # sparse @ dense: pad the dense operand to the kernel's alignment,
        # then slice the result back down to the original number of columns.
        row, col = B.shape
        B_padded = A._pad_dense_input(B)
        res = A._mm(B_padded)
        return res[:, :col]
    else:
        # dense @ sparse: computed as (B.t() @ A.t()).t() so the sparse operand
        # is always the first argument of the fused kernel.
        B_t = B.t()
        assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
        row, col = A.shape
        A_padded = B._pad_dense_input(A)
        res = B_t._mm(A_padded.t()).t()
        return res[:row, :]


def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 3
    bias, A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )
    # Only dense @ sparse (+ bias) is supported; computed as (B.t() @ A.t() + bias).t().
    B_t = B.t()
    assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
    row, col = A.shape
    A_padded = B_t._pad_dense_input(A)
    result = B_t._mm(A_padded.t(), bias=bias).t()
    return result[:row, :]


def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) in [2, 3]
    A, B = args[:2]
    bias = args[2] if len(args) == 3 else None

    # Flatten any leading batch dimensions so the matmul handlers see 2-D inputs,
    # then restore the original batch shape on the way out.
    shape = A.shape
    A_2d = A.view(-1, shape[-1])

    if bias is None:
        res = A_2d @ B.t()
    else:
        res = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, A_2d, B.t()],
        )

    return res.view(*shape[:-1], -1)
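

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this module's API: the handlers above are
# reached through SparseSemiStructuredTensor.__torch_dispatch__, so ordinary
# torch ops on the subclass route here automatically. `_demo_semi_sparse_mm`
# is a hypothetical helper added only as an example; it assumes a CUDA device
# with 2:4 sparse kernel support.
# ---------------------------------------------------------------------------
def _demo_semi_sparse_mm() -> torch.Tensor:
    from torch.sparse import to_sparse_semi_structured

    # Dense weight already in the 2:4 pattern (two nonzeros per group of four).
    dense = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
    A_sparse = to_sparse_semi_structured(dense)  # compress to the packed form

    x = torch.rand(128, 128, dtype=torch.float16, device="cuda")

    # torch.mm on the subclass dispatches to semi_sparse_mm defined above.
    return torch.mm(A_sparse, x)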