#include <ATen/ScalarOps.h>
#include <ATen/Functions.h>
#include <ATen/Tensor.h>
#include <ATen/autocast_mode.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <ATen/native/sparse/cuda/ComputeSparseTile.h>
#include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
#include <cuda_runtime.h>
#endif

namespace at::native {

#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
struct MetadataCuSparseLt {
  // Format used by cuSparseLt
  // This is based on reverse-engineering, for a visual illustration:
  // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g29afe95bda8_0_0
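  // The format stores 1 bit of metadata per dense value, so a 32x32 dense
  // tile corresponds to 32*32 = 1024 metadata bits, i.e.
  // (32 * 32) / (8 * sizeof(ElementInputE)) metadata words (64 words when
  // ElementInputE is a 16-bit type).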
  static constexpr int kStrideBlock32x32 = (32 * 32) / (sizeof(ElementInputE) * 8);

  ElementInputE* _meta;
  ElementInputE* _meta_trans;
  int64_t _rows;
  int64_t _cols;

  static int64_t getMetadataSize(int rows, int cols)
  {
    TORCH_CHECK(rows % 128 == 0 && cols % 128 == 0, "Only supports rows/cols multiples of 128");
    // 1 bit per dense value
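    // e.g. a 128x128 input with 16-bit ElementInputE: 128 * 128 / 16 = 1024 metadata words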
    return (rows * cols) / (8 * sizeof(ElementInputE));
  }

  // < return value of the function, packed, packed_meta >
  static std::tuple<Tensor, Tensor, Tensor> create_compressed_representation(int rows, int cols, at::Tensor const& like)
  {
    TORCH_CHECK(
        like.scalar_type() == at::ScalarType::Half ||
        like.scalar_type() == at::ScalarType::BFloat16);
    constexpr int kBytesPerScalar = 2;
    int64_t data_scalars = rows * cutlass::ceil_div(cols, 2);
    int64_t meta_scalars = getMetadataSize(rows, cols);

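    // A single contiguous buffer holds both the packed values and the
    // metadata; the metadata region is simply reinterpreted as 2-byte scalars
    // of the input dtype.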
    at::Tensor storage = at::empty(
        {(data_scalars + meta_scalars)},
        at::TensorOptions().device(like.device()).dtype(like.dtype()));

    using at::indexing::Slice;
    using at::indexing::None;
    at::Tensor packed = storage.index({Slice(None, data_scalars)})
                            .view({rows, cutlass::ceil_div(cols, 2)});
    at::Tensor metadata = storage.index({Slice(data_scalars, None)});
    // TODO: Cast metadata to Short
    static_assert(kBytesPerScalar == 2, "or modify the last dim below");
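    // Last dim: a 128x32 slab of dense values is four stacked 32x32 tiles,
    // i.e. 4 * kStrideBlock32x32 = 256 of the 2-byte scalars asserted above
    // (assuming 16-bit ElementInputE).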
    metadata = metadata.view({rows / 128, cols / 32, 256});
    return std::make_tuple(storage, packed, metadata);
  }

  MetadataCuSparseLt(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
    _meta = (ElementInputE*)metaN.data_ptr();
    _meta_trans = (ElementInputE*)metaT.data_ptr();
    _rows = rows;
    _cols = cols;
  }
  CUTLASS_HOST_DEVICE
  static int64_t _get_meta_offset(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col,
      int totalRows) {
    int64_t offset = 0;
    // warp-level: Find the 128x64 tile
    offset += (warp_row / 128) * (kStrideBlock32x32 * 8);
    offset += (warp_col / 64) * (kStrideBlock32x32 * 8) * (totalRows / 128);
    // Find the 32x32 tile inside
    offset += (((warp_row + thread_row) % 128) / 32) * kStrideBlock32x32;
    offset += (((warp_col + thread_col) % 64) / 32) * (kStrideBlock32x32 * 4);
    // Inside the 32x32 tile
    offset += (warp_row % 32) * 2;
    // Top/bottom 16x16 tile
    offset += ((thread_row % 32) / 16) * 4;
    // Left/right 16x16 tile
    offset += ((thread_col % 32) / 16) * 2;
    return offset;
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaN(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta +
        _get_meta_offset(warp_row, thread_row, warp_col, thread_col, _rows);
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaT(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta_trans +
        _get_meta_offset(warp_col, thread_col, warp_row, thread_row, _cols);
  }
};

struct MetadataCutlass {
  // Layout needed to run 2:4 gemms in CUTLASS.
  // There is basically a hardware-specific value for every
  // 32x32 dense tile (1024 bits). These tiles are then
  // stored in a column-major fashion.
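  // The metadata budget is the same as for the cuSparseLt format: 1 bit per
  // dense value (roundedx * roundedy / 16 int16 values below), just reordered
  // into the layout CUTLASS' sparse GEMM expects.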
  ElementInputE* _meta;
  ElementInputE* _meta_trans;
  int64_t _meta_reordered_sy;
  int64_t _meta_trans_reordered_sx;

  static std::tuple<
      at::Tensor, // return value of the function
      at::Tensor, // packed
      at::Tensor // packed_meta
      >
  create_compressed_representation(int rows, int cols, at::Tensor const& like) {
    TORCH_CHECK(
        like.scalar_type() == at::ScalarType::Half ||
        like.scalar_type() == at::ScalarType::BFloat16);
    auto roundedx = cutlass::round_up(rows, kWarpX);
    auto roundedy = cutlass::round_up(cols, kWarpY);

    // NB: Writing to `packed` tensors in transposed manner
    at::Tensor packed =
        at::empty({roundedx, cutlass::ceil_div(roundedy, 2)}, like.options());
    at::Tensor packed_meta = at::empty(
                                 {roundedx * roundedy / 16},
                                 like.options().dtype(at::ScalarType::Short))
                                 .view({roundedy / 32, roundedx, 2})
                                 .permute({1, 2, 0});
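    // Unlike the cuSparseLt path there is no fused storage tensor, so
    // `packed` doubles as the function's return value.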
    return std::make_tuple(packed, packed, packed_meta);
  }
  MetadataCutlass(at::Tensor metaN, at::Tensor metaT, int rows, int cols) {
    _meta = (ElementInputE*)metaN.data_ptr();
    _meta_reordered_sy = metaN.stride(2);
    _meta_trans = (ElementInputE*)metaT.data_ptr();
    _meta_trans_reordered_sx = metaT.stride(2);
  }
  CUTLASS_HOST_DEVICE
  int64_t _get_meta_offset(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col,
      int64_t stride) const {
    int64_t offset = 0;
    offset += warp_row * 2 + (warp_col / 32) * stride;
    // A single warp is 32x64. The right 32x32 tile is at a different position
    offset += 64 * (thread_row / 32);
    offset += (thread_col / 32) * stride;
    // Top/bottom 16x16 tile
    offset += ((thread_row % 32) / 16) * 4;
    // Left/right 16x16 tile
    offset += ((thread_col % 32) / 16) * 2;
    return offset;
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaN(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta +
        _get_meta_offset(
               warp_row, thread_row, warp_col, thread_col, _meta_reordered_sy);
  }
  CUTLASS_HOST_DEVICE
  ElementInputE* get_metaT(
      int warp_row,
      int thread_row,
      int warp_col,
      int thread_col) const {
    return _meta_trans +
        _get_meta_offset(
               warp_col,
               thread_col,
               warp_row,
               thread_row,
               _meta_trans_reordered_sx);
  }
};

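// Blocks are at most one warp (32 threads); the second __launch_bounds__
// argument asks the compiler to target at least 20 resident blocks per SM.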
template <typename KT, typename Metadata, typename Algorithm>
__global__ void __launch_bounds__(32 /* num_threads */, 20)
    sparse_semi_structured_tile_kernel(
        typename KT::Params p,
        Metadata metadata,
        Algorithm algo) {
  KT::sparse_semi_structured_tile_kernel(p, metadata, algo);
}

template <typename Element, typename MetadataFormat>
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> sparse_semi_structured_tile_typed(
        const at::Tensor input,
        std::string algorithm)
{
  using KT = KernelTypes<Element>;
  std::optional<at::cuda::CUDAGuard> device_guard;
  if (!input.is_meta()) {
    device_guard.emplace(input.device());
  }

  TORCH_CHECK(input.dim() == 2, "Can only sparsify 2d tensors");
  TORCH_CHECK(
      input.stride(1) == 1,
      "Can only sparsify contiguous tensors. Sparsify the transpose otherwise.");

  auto rows = input.size(0);
  auto cols = input.size(1);

  auto [compressed, packed, packed_meta_reordered] =
      MetadataFormat::create_compressed_representation(rows, cols, input);
  auto [compressed_trans, packed_trans, packed_trans_meta_reordered] =
      MetadataFormat::create_compressed_representation(cols, rows, input);
  TORCH_CHECK(
      input.size(1) % 32 == 0, "Number of cols should be multiple of 32");

  typename KT::Params p;
  p.input = (Element const*)input.data_ptr();
  p.input_s0 = input.stride(0);
  p.input_dim0 = input.size(0);
  p.input_dim1 = input.size(1);

  p.packed = (Element*)packed.data_ptr();
  p.packed_stride = packed.stride(0);
  p.packed_trans = (Element*)packed_trans.data_ptr();
  p.packed_trans_stride = packed_trans.stride(0);

  MetadataFormat metadata = MetadataFormat(
      packed_meta_reordered, packed_trans_meta_reordered, rows, cols);
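  // One mask entry of sizeof(p.threads_masks[0]) bytes per thread of the
  // launch grid, allocated as raw bytes and reinterpreted just below.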
  at::Tensor threads_masks = at::empty(
      {p.getBlocksGrid().x * p.getThreadsGrid().x,
       p.getBlocksGrid().y * p.getThreadsGrid().y,
       sizeof(p.threads_masks[0])},
      input.options().dtype(at::ScalarType::Byte));
  p.threads_masks = (uint64_t*)threads_masks.data_ptr();

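  // Dispatch on the algorithm name: named_algorithms() enumerates the known
  // selection algorithms and launchKernel fires only for the matching one.
  // For meta tensors we still validate the name but skip the actual launch.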
  bool kernel_launched = false;
  auto launchKernel = [&](auto algo, std::string const& algo_name) {
    if (algo_name == algorithm) {
      kernel_launched = true;
      if (input.is_meta()) {
        return;
      }
      size_t smem_bytes = 0;
      sparse_semi_structured_tile_kernel<KT>
          <<<p.getBlocksGrid(),
             p.getThreadsGrid(),
             smem_bytes,
             at::cuda::getCurrentCUDAStream()>>>(p, metadata, algo);
    }
  };
  named_algorithms(launchKernel);
  TORCH_CHECK(kernel_launched, "Unknown algorithm \"", algorithm, "\"");
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return std::make_tuple(
      compressed,
      packed_meta_reordered,
      compressed_trans,
      packed_trans_meta_reordered,
      threads_masks);
}
#endif

// <packed, packed_meta_reordered, packed_trans, packed_trans_meta_reordered, threads_masks>
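// Sketch of the Python-side call, assuming the usual codegen binding for this
// native function: torch._sparse_semi_structured_tile(input, algorithm, use_cutlass);
// see native_functions.yaml for the authoritative signature and defaults.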
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _sparse_semi_structured_tile(
  const Tensor& input,
  c10::string_view algorithm,
  bool use_cutlass)
{
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
  AT_ERROR("_sparse_semi_structured_tile: not supported");
  return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{});
#else
  std::string algo(algorithm.data(), algorithm.size());

  auto runTyped = [&](auto type)
  {
    using ElementT = decltype(type);
    if (use_cutlass) {
      return sparse_semi_structured_tile_typed<ElementT, MetadataCutlass>(input, algo);
    }
    else {
      return sparse_semi_structured_tile_typed<ElementT, MetadataCuSparseLt>(input, algo);
    }
  };

  if (input.scalar_type() == at::ScalarType::Half)
  {
    return runTyped(cutlass::half_t());
  } else {
    TORCH_CHECK(
        input.scalar_type() == at::ScalarType::Half ||
        input.scalar_type() == at::ScalarType::BFloat16, input.scalar_type());
    return runTyped(cutlass::bfloat16_t());
  }
#endif
}

} // namespace at::native