#include <ATen/ScalarOps.h>
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/autocast_mode.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#include <cuda_runtime.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
struct MetadataCuSparseLt {
  // Metadata format used by cuSparseLt.
  // This layout was reverse-engineered; see the following slides for a visual illustration:
  // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g29afe95bda8_0_0
  static constexpr int kStrideBlock32x32 = (32 * 32) / (sizeof(ElementInputE) * 8);
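  // Illustrative note (assuming ElementInputE is a 16-bit metadata word):
  // kStrideBlock32x32 == (32 * 32) / 16 == 64 metadata words per 32x32 dense tile.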

  ElementInputE* _meta;
  ElementInputE* _meta_trans;
  int64_t _rows;
  int64_t _cols;

  static int64_t getMetadataSize(int rows, int cols)
  {
    TORCH_CHECK(rows % 128 == 0 && cols % 128 == 0, "Only supports rows/cols multiples of 128");
    // 1 bit per dense value
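    // e.g. (illustrative, assuming 16-bit ElementInputE): a 128x128 input
    // needs 128 * 128 / 16 == 1024 metadata words.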
    return (rows * cols) / (8 * sizeof(ElementInputE));
  }

  // <return value of the function, packed, packed_meta>
  static std::tuple<Tensor, Tensor, Tensor> create_compressed_representation(int rows, int cols, at::Tensor const& like)
  {
    TORCH_CHECK(
        like.scalar_type() == at::ScalarType::Half ||
        like.scalar_type() == at::ScalarType::BFloat16);
    constexpr int kBytesPerScalar = 2;
    int64_t data_scalars = rows * cutlass::ceil_div(cols, 2);
    int64_t meta_scalars = getMetadataSize(rows, cols);

    at::Tensor storage = at::empty(
        {(data_scalars + meta_scalars)},
        at::TensorOptions().device(like.device()).dtype(like.dtype()));

    using at::indexing::Slice;
    using at::indexing::None;
    at::Tensor packed = storage.index({Slice(None, data_scalars)})
                            .view({rows, cutlass::ceil_div(cols, 2)});
    at::Tensor metadata = storage.index({Slice(data_scalars, None)});
    // TODO: Cast metadata to Short
    static_assert(kBytesPerScalar == 2, "or modify the last dim below");
    metadata = metadata.view({rows / 128, cols / 32, 256});
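    // Shape check (illustrative, fp16 `like` so kBytesPerScalar == 2): for a
    // 128x128 input, data_scalars == 128 * 64 == 8192 and meta_scalars == 1024,
    // so `storage` holds 9216 scalars, `packed` is viewed as {128, 64} and
    // `metadata` as {1, 4, 256}.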
    return std::make_tuple(storage, packed, metadata);
  }

  MetadataCuSparseLt(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
    _meta = (ElementInputE*)metaN.data_ptr();
    _meta_trans = (ElementInputE*)metaT.data_ptr();
    _rows = rows;
    _cols = cols;
  }
  CUTLASS_HOST_DEVICE
  static int64_t _get_meta_offset(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col,
      int totalRows) {
    int64_t offset = 0;
    // warp-level: Find the 128x64 tile
    offset += (warp_row / 128) * (kStrideBlock32x32 * 8);
    offset += (warp_col / 64) * (kStrideBlock32x32 * 8) * (totalRows / 128);
    // Find the 32x32 tile inside
    offset += (((warp_row + thread_row) % 128) / 32) * kStrideBlock32x32;
    offset += (((warp_col + thread_col) % 64) / 32) * (kStrideBlock32x32 * 4);
    // Inside the 32x32 tile
    offset += (warp_row % 32) * 2;
    // Top/bottom 16x16 tile
    offset += ((thread_row % 32) / 16) * 4;
    // Left/right 16x16 tile
    offset += ((thread_col % 32) / 16) * 2;
    return offset;
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaN(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta +
        _get_meta_offset(warp_row, thread_row, warp_col, thread_col, _rows);
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaT(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta_trans +
        _get_meta_offset(warp_col, thread_col, warp_row, thread_row, _cols);
  }
};

struct MetadataCutlass {
  // Layout needed to run 2:4 gemms in CUTLASS.
  // There is a hardware-specific metadata value for every 32x32 dense tile
  // (1024 bits), and these per-tile values are stored in column-major order.
  ElementInputE* _meta;
  ElementInputE* _meta_trans;
  int64_t _meta_reordered_sy;
  int64_t _meta_trans_reordered_sx;

  static std::tuple<
      at::Tensor, // return value of the function
      at::Tensor, // packed
      at::Tensor // packed_meta
      >
  create_compressed_representation(int rows, int cols, at::Tensor const& like) {
    TORCH_CHECK(
        like.scalar_type() == at::ScalarType::Half ||
        like.scalar_type() == at::ScalarType::BFloat16);
    auto roundedx = cutlass::round_up(rows, kWarpX);
    auto roundedy = cutlass::round_up(cols, kWarpY);

    // NB: Writing to `packed` tensors in transposed manner
    at::Tensor packed =
        at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, like.options());
    at::Tensor packed_meta = at::empty(
                                 {roundedx * roundedy / 16},
                                 like.options().dtype(at::ScalarType::Short))
                                 .view({roundedy / 32, roundedx, 2})
                                 .permute({1, 2, 0});
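    // Layout sketch (illustrative): `packed_meta` holds one 16-bit metadata
    // word per 16 dense values (roundedx * roundedy / 16 entries). After the
    // view + permute above its shape is {roundedx, 2, roundedy / 32}, with
    // stride(2) == 2 * roundedx; the constructor below reads this back as
    // `_meta_reordered_sy`.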
    return std::make_tuple(packed, packed, packed_meta);
  }
  MetadataCutlass(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
    _meta = (ElementInputE*)metaN.data_ptr();
    _meta_reordered_sy = metaN.stride(2);
    _meta_trans = (ElementInputE*)metaT.data_ptr();
    _meta_trans_reordered_sx = metaT.stride(2);
  }
  CUTLASS_HOST_DEVICE
  int64_t _get_meta_offset(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col,
      int64_t stride) const {
    int64_t offset = 0;
    offset += warp_row * 2 + (warp_col / 32) * stride;
    // A single warp is 32x64. The right 32x32 tile is at a different position
    offset += 64 * (thread_row / 32);
    offset += (thread_col / 32) * stride;
    // Top/bottom 16x16 tile
    offset += ((thread_row % 32) / 16) * 4;
    // Left/right 16x16 tile
    offset += ((thread_col % 32) / 16) * 2;
    return offset;
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaN(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta +
        _get_meta_offset(
            warp_row, thread_row, warp_col, thread_col, _meta_reordered_sy);
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaT(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta_trans +
        _get_meta_offset(
            warp_col,
            thread_col,
            warp_row,
            thread_row,
            _meta_trans_reordered_sx);
  }
};

template <typename KT, typename Metadata, typename Algorithm>
__global__ void __launch_bounds__(32 /* num_threads */, 20)
    sparse_semi_structured_tile_kernel(
        typename KT::Params p,
        Metadata metadata,
        Algorithm algo) {
  KT::sparse_semi_structured_tile_kernel(p, metadata, algo);
}
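// Note on the launch bounds above: __launch_bounds__(32, 20) caps each block
// at 32 threads (a single warp) and asks the compiler to keep register usage
// low enough for at least 20 resident blocks per SM.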

template <typename Element, typename MetadataFormat>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> sparse_semi_structured_tile_typed(
    const at::Tensor input,
    std::string algorithm)
{
  using KT = KernelTypes<Element>;
  std::optional<at::cuda::CUDAGuard> device_guard;
  if (!input.is_meta()) {
    device_guard.emplace(input.device());
  }

  TORCH_CHECK(input.dim() == 2, "Can only sparsify 2d tensors");
  TORCH_CHECK(
      input.stride(1) == 1,
      "Can only sparsify contiguous tensors. Sparsify the transpose otherwise.");

  auto rows = input.size(0);
  auto cols = input.size(1);

  auto [compressed, packed, packed_meta_reordered] =
      MetadataFormat::create_compressed_representation(rows, cols, input);
  auto [compressed_trans, packed_trans, packed_trans_meta_reordered] =
      MetadataFormat::create_compressed_representation(cols, rows, input);
  TORCH_CHECK(
      input.size(1) % 32 == 0, "Number of cols should be multiple of 32");

  typename KT::Params p;
  p.input = (Element const*)input.data_ptr();
  p.input_s0 = input.stride(0);
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);

  p.packed = (Element*)packed.data_ptr();
  p.packed_stride = packed.stride(0);
  p.packed_trans = (Element*)packed_trans.data_ptr();
  p.packed_trans_stride = packed_trans.stride(0);

  MetadataFormat metadata = MetadataFormat(
      packed_meta_reordered, packed_trans_meta_reordered, rows, cols);
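  // `threads_masks` is a per-thread output buffer: one entry of
  // sizeof(p.threads_masks[0]) bytes for every CUDA thread in the launch grid,
  // recording the sparsity decisions made by that thread. It is returned to
  // the caller so the selected 2:4 pattern can be reused later.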
  at::Tensor threads_masks = at::empty(
      {p.getBlocksGrid().x * p.getThreadsGrid().x,
       p.getBlocksGrid().y * p.getThreadsGrid().y,
       sizeof(p.threads_masks[0])},
      input.options().dtype(at::ScalarType::Byte));
  p.threads_masks = (uint64_t*)threads_masks.data_ptr();

  bool kernel_launched = false;
  auto launchKernel = [&](auto algo, std::string const& algo_name) {
    if (algo_name == algorithm) {
      kernel_launched = true;
      if (input.is_meta()) {
        return;
      }
      size_t smem_bytes = 0;
      sparse_semi_structured_tile_kernel<KT>
          <<<p.getBlocksGrid(),
             p.getThreadsGrid(),
             smem_bytes,
             at::cuda::getCurrentCUDAStream()>>>(p, metadata, algo);
    }
  };
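  // named_algorithms enumerates the available (algorithm, name) pairs and
  // invokes the lambda for each of them; only the entry whose name matches the
  // requested `algorithm` string launches the kernel, and the launch is
  // skipped entirely for meta tensors.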
  named_algorithms(launchKernel);
  TORCH_CHECK(kernel_launched, "Unknown algorithm \"", algorithm, "\"");
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return std::make_tuple(
      compressed,
      packed_meta_reordered,
      compressed_trans,
      packed_trans_meta_reordered,
      threads_masks);
}
#endif

// <packed, packed_meta_reordered, packed_trans, packed_trans_meta_reordered, threads_masks>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _sparse_semi_structured_tile(
    const Tensor& input,
    c10::string_view algorithm,
    bool use_cutlass)
{
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_tile: not supported");
  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{});
#else
  std::string algo(algorithm.data(), algorithm.size());

  auto runTyped = [&](auto type)
  {
    using ElementT = decltype(type);
    if (use_cutlass) {
      return sparse_semi_structured_tile_typed<ElementT, MetadataCutlass>(input, algo);
    } else {
      return sparse_semi_structured_tile_typed<ElementT, MetadataCuSparseLt>(input, algo);
    }
  };

  if (input.scalar_type() == at::ScalarType::Half) {
    return runTyped(cutlass::half_t());
  } else {
    TORCH_CHECK(
        input.scalar_type() == at::ScalarType::Half ||
        input.scalar_type() == at::ScalarType::BFloat16, input.scalar_type());
    return runTyped(cutlass::bfloat16_t());
  }
#endif
}
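
// Usage sketch (illustrative; assumes the standard generated dispatcher
// binding for this op):
//
//   at::Tensor dense = at::randn(
//       {128, 128}, at::dtype(at::kHalf).device(at::kCUDA));
//   auto [packed, meta, packed_t, meta_t, masks] =
//       at::_sparse_semi_structured_tile(
//           dense, /*algorithm=*/"", /*use_cutlass=*/true);
//
// The empty algorithm string is assumed to select the default entry registered
// via named_algorithms(); any other value must match one of the names that
// helper enumerates, otherwise the "Unknown algorithm" TORCH_CHECK above fires.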

} // namespace at::native