#pragma once

#include <ATen/native/sparse/cuda/StaticSort.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/bfloat16.h>
#include <cutlass/fast_math.h>
#include <cutlass/half.h>
#include <cutlass/integer_subbyte.h>

namespace at::native {

using cutlass::uint1b_t;
using cutlass::uint2b_t;
using cutlass::uint4b_t;
using uint8b_t = cutlass::integer_subbyte<8, false>;
using ReorderedLayoutInputE = cutlass::layout::ColumnMajorInterleaved<2>;
using ElementInputE = uint16_t;
constexpr int kWarpX = 32;
constexpr int kWarpY = 64;
constexpr int kThreadX = 8;
constexpr int kThreadY = 8;
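
// Each warp owns a kWarpX x kWarpY (32x64) tile of the input, and each of
// its threads an 8x8 (kThreadX x kThreadY) sub-tile, so a warp is laid out
// as a (32/8) x (64/8) = 4x8 grid of threads (32 threads total).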

// bitmask of selected values, in col-major storage
// e.g.: indices & (1 << (col + 4 * row))
using Indices4x4 = uint16_t;
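// For example, a tile that keeps the values at (row=0, col=1) and
// (row=2, col=3) has indices == (1 << 1) | (1 << (3 + 4 * 2)) == 0x802.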

struct Tile8x8Masks {
  Indices4x4 a, b, c, d;
  CUTLASS_DEVICE Tile8x8Masks() {
    a = b = c = d = 0;
  }
};

static_assert(sizeof(Tile8x8Masks) == 8, "should be exactly uint64_t");

// Each thread has data for an 8x8 area of the input tensor
// Due to the very specific format of the metadata, 32 consecutive bits
// of the metadata tensor will live in 4 different threads.
// This function does the required warp shuffling to send data to the
// right threads.
// This took some time to write (and get right), hopefully these slides
// can help
// https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g249eb2e2f2e_0_28
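// NOTE: this assumes the launch configuration from Params::getThreadsGrid()
// below, i.e. blockDim = (4, 8), so lane = threadIdx.x + 4 * threadIdx.y:
// a shuffle XOR-ing the lane with 4 exchanges data between threads whose
// threadIdx.y parities differ, and XOR 1 between threads whose threadIdx.x
// parities differ.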
CUTLASS_DEVICE uint32_t
warp_shuffle_meta(uint32_t meta_ab, bool transposed = false) {
  // The required format is
  // (one line = 32 bits)
  // a[ 0,  0:16] a[ 8,  0:16] <- T0 [left]
  // a[ 0, 16:32] a[ 8, 16:32]
  // a[16,  0:16] a[24,  0:16]
  // a[16, 16:32] a[24, 16:32]
  // a[ 1,  0:16] a[ 9,  0:16] <- T4
  // a[ 1, 16:32] a[ 9, 16:32]
  // a[17,  0:16] a[25,  0:16]
  // a[17, 16:32] a[25, 16:32]
  // a[ 2,  0:16] a[10,  0:16] <- T1 [left, bottom]
  // a[ 2, 16:32] a[10, 16:32]
  // a[18,  0:16] a[26,  0:16]
  // a[18, 16:32] a[26, 16:32]
  // a[ 3,  0:16] a[11,  0:16] <- T5 [bottom]
  // a[ 3, 16:32] a[11, 16:32]
  // a[19,  0:16] a[27,  0:16]
  // a[19, 16:32] a[27, 16:32]
  // ...
  // Use warp-shuffles to send data around threads
  bool thread_left = (threadIdx.y % 2) == 0;
  bool thread_bottom = threadIdx.x % 2;

  if (transposed) {
    thread_left = (threadIdx.x % 2) == 0;
    thread_bottom = threadIdx.y % 2;
  }

  uint8b_t stage0_data[2] = {
      uint8b_t(meta_ab >> (8 * thread_left)),
      uint8b_t(meta_ab >> (8 * (thread_left + 2)))};
  // shfl t0-t4 / t1-t5
  stage0_data[0] =
      uint8b_t(__shfl_xor_sync(0xffffffff, stage0_data[0], transposed ? 1 : 4));
  stage0_data[1] =
      uint8b_t(__shfl_xor_sync(0xffffffff, stage0_data[1], transposed ? 1 : 4));

  uint16_t line0 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left))))
      << ((1 - thread_left) * 8);
  line0 |= int(stage0_data[0]) << (thread_left * 8);
  uint16_t line1 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left + 2))))
      << ((1 - thread_left) * 8);
  line1 |= int(stage0_data[1]) << (thread_left * 8);

  uint16_t stage1_data = thread_bottom ? line0 : line1;
  stage1_data = __shfl_xor_sync(0xffffffff, stage1_data, transposed ? 4 : 1);

  uint32_t final_metadata;
  if (thread_bottom) {
    final_metadata = uint32_t(stage1_data) | uint32_t(line1) << 16;
  } else {
    final_metadata = uint32_t(stage1_data) << 16 | uint32_t(line0);
  }
  return final_metadata;
}

CUTLASS_DEVICE void warp_shuffle_and_write_meta(
    ElementInputE* metadata_quad,
    uint32_t meta_ab,
    bool transposed = false) {
  bool thread_left = (threadIdx.y % 2) == 0;
  bool thread_bottom = threadIdx.x % 2;

  if (transposed) {
    thread_left = (threadIdx.x % 2) == 0;
    thread_bottom = threadIdx.y % 2;
  }

  uint32_t final_metadata = warp_shuffle_meta(meta_ab, transposed);

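  // Each thread of a (left/right, top/bottom) quad writes one 32-bit word
  // (two 16-bit metadata elements); the four words land at uint32_t indices
  // 0, 4, 8 and 12 of the reordered metadata quad.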
  int index = (!thread_left + 2 * thread_bottom) * 4;
  ((uint32_t*)metadata_quad)[index] = final_metadata;
}

template <typename Element_>
struct KernelTypes {
  using Element = Element_;
  using Fragment =
      cutlass::Array<Element, 8>; // always read from gmem in chunks of 128 bits
  using Fragment4 = cutlass::Array<Element, 4>;
  using ValuesPacked = cutlass::Array<Element, 8>; // first 4: col 0, last 4: col 1

  struct Params {
    /// inputs
    Element const* input;
    int64_t input_s0;
    int64_t input_dim0;
    int64_t input_dim1;

    /// outputs
    Element* packed;
    int64_t packed_stride;

    Element* packed_trans;
    int64_t packed_trans_stride;

    uint64_t* threads_masks;

    __host__ dim3 getBlocksGrid() const {
      return dim3(
          cutlass::ceil_div(input_dim0, kWarpX),
          cutlass::ceil_div(input_dim1, kWarpY),
          1);
    }

    static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
      return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
    }

    CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
      Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
      gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
      int64_t strideX = gridDim.y * getThreadsGrid().y;
      gmem_threads_masks +=
          (blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
      return gmem_threads_masks;
    }
  };
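
  // A minimal host-side launch sketch. This is hypothetical: the actual
  // launch code lives in the .cu files that include this header, and
  // `kernel` stands for any __global__ function wrapping the device entry
  // points below.
  //
  //   Params p{/* input, strides, dims, output pointers, threads_masks */};
  //   kernel<<<p.getBlocksGrid(), Params::getThreadsGrid(), 0, stream>>>(p);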

  struct Tile4x4Accessor {
    using Element = Element_;

    Fragment (&_lines)[8];
    int _start_row;
    int _start_col;

    CUTLASS_DEVICE Tile4x4Accessor(
        Fragment (&lines)[8],
        int start_row,
        int start_col)
        : _lines(lines), _start_row(start_row), _start_col(start_col) {}

    CUTLASS_DEVICE typename Fragment::reference at(int r, int c) {
      return _lines[r + _start_row][c + _start_col];
    }
  };

  struct Tile4x4Packed {
    Fragment4 values[2];
    CUTLASS_DEVICE Tile4x4Packed() {
      values[0].clear();
      values[1].clear();
    }
  };

  // Returns a packed 4x4 tile (i.e. 2x4 values) corresponding to the values
  // selected in `indices`. Also fills the `meta` array in the right format
  // for consumption in the TensorCores.
  // Example (`values[r][c]` denotes the input value at row `r`, col `c`):
  //  indices:  0011
  //            1001
  //            1001
  //            0100 (<- note, only 1 value on the last line)
  //  packed: values[0][2] values[1][0] values[2][0] values[3][1]
  //          values[0][3] values[1][3] values[2][3] Element(0)
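  // For the same example, each row contributes a 4-bit nibble to `meta`,
  // shifted to bit (8 * row + meta_pos): the low 2 bits are the source
  // column of packed position 0, the high 2 bits that of position 1.
  // Row 0 above gives 2 | (3 << 2) = 0b1110; rows 1 and 2 give
  // 0 | (3 << 2) = 0b1100. Rows with fewer than two selected values leave
  // the unused 2-bit fields uninitialized.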
  CUTLASS_DEVICE static Tile4x4Packed pack_4x4(
      Indices4x4 indices,
      Tile4x4Accessor tile,
      uint32_t& meta,
      int meta_pos,
      bool transpose = false) {
    Tile4x4Packed packed;
    CUTLASS_PRAGMA_UNROLL
    for (int row = 0; row < 4; ++row) {
      uint2b_t col0_from, col1_from;
      auto packValue = [&](uint2b_t col_to, uint2b_t col_from) {
        auto value = transpose ? tile.at(col_from, row).get()
                               : tile.at(row, col_from).get();
        packed.values[col_to][row] = value;
        if (col_to == uint2b_t(0)) {
          col0_from = col_from;
        } else {
          col1_from = col_from;
        }
      };
      auto isSelected = [&](int col) {
        if (transpose) {
          return indices & (1 << (row + 4 * col));
        }
        return indices & (1 << (col + 4 * row));
      };
      // Process cols 0/1
      // We know that col0 is always packed to position 0 if it's there,
      // and col1 is packed to pos 0 or 1 (depending on whether col0 is
      // selected)
      if (isSelected(1)) {
        packValue(uint2b_t(0), uint2b_t(1));
      }
      if (isSelected(0)) {
        packValue(uint2b_t(0), uint2b_t(0));
      }
      if (isSelected(0) && isSelected(1)) {
        packValue(uint2b_t(1), uint2b_t(1));
      }
      // Process cols 2/3
      // same sort of heuristic
      if (isSelected(2)) {
        packValue(uint2b_t(1), uint2b_t(2));
      }
      if (isSelected(3)) {
        packValue(uint2b_t(1), uint2b_t(3));
      }
      if (isSelected(2) && isSelected(3)) {
        packValue(uint2b_t(0), uint2b_t(2));
      }
      int add_mask = (col0_from | (col1_from << 2)) << (8 * row + meta_pos);
      meta |= add_mask;
    }
    return packed;
  }

  struct Tile8x8Meta {
    // meta_ab[row] |= (real_col << (8*row + 2*pos))
    uint32_t meta_ab;
    uint32_t meta_cd;
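    // i.e. row r of the 8x8 tile occupies bits [8r, 8r+8): the low nibble
    // holds the two 2-bit source columns for tile A (resp. C), the high
    // nibble those for tile B (resp. D); see `meta_pos` in pack_4x4.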

    // meta_ac_trans[col] |= (real_row << (8*col + 2*pos))
    uint32_t meta_ac_trans;
    uint32_t meta_bd_trans;

    CUTLASS_DEVICE Tile8x8Meta() {
      meta_ab = meta_cd = meta_ac_trans = meta_bd_trans = 0;
    }
  };

  CUTLASS_DEVICE static void writePacked(
      Element* ptr,
      Fragment4 packed0,
      Fragment4 packed1) {
    Fragment write;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      write[i] = packed0[i].get();
      write[i + 4] = packed1[i].get();
    }
    cutlass::arch::global_store<Fragment, sizeof(Fragment)>(write, ptr, true);
  }

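  // Writes two packed 4x4 tiles side by side: output row i is
  // [a.col0[i], a.col1[i], b.col0[i], b.col1[i]], i.e. the 4 values kept
  // out of 8 input columns, with one vectorized store per row.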
  CUTLASS_DEVICE static void writePackedT(
      Element* ptr,
      int64_t stride,
      Tile4x4Packed a,
      Tile4x4Packed b) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      Fragment4 write;
      write[0] = a.values[0][i].get();
      write[1] = a.values[1][i].get();
      write[2] = b.values[0][i].get();
      write[3] = b.values[1][i].get();
      cutlass::arch::global_store<Fragment4, sizeof(Fragment4)>(
          write, ptr + i * stride, true);
    }
  }

  template <typename Algorithm, typename MetadataStore>
  CUTLASS_DEVICE static void sparse_semi_structured_tile_kernel(
      Params p,
      MetadataStore metadata_gmem,
      Algorithm compute_tile_indices) {
    // Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
    // A, B, C and D, as displayed in the following schema:
    // +---+---+
    // | A | B |
    // +---+---+
    // | C | D |
    // +---+---+
    // Each warp (32 threads) will then be responsible for a 32x64 tile of
    // the input.
    // This configuration allows reading/writing data in 128-bit chunks.
    // These memory accesses are coalesced at the warp level into 128 bytes.
    // See also:
    // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g2494f30c7cf_0_0

    // Top-left of the 8x8 tile we own
    int warp_x = blockIdx.x * kWarpX;
    int warp_y = blockIdx.y * kWarpY;
    int x = warp_x + threadIdx.x * kThreadX;
    int y = warp_y + threadIdx.y * kThreadY;

    Element const* input = p.input + x * p.input_s0 + y;
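    // 2:4 sparsity keeps 2 of every 4 values along the packed dimension, so
    // the packed outputs have half as many columns as the input; hence the
    // y / 2 here (and x / 2 for the transposed output).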
    Element* packed = p.packed + x * p.packed_stride + (y / 2);
    Element* packed_trans =
        p.packed_trans + (x / 2) + y * p.packed_trans_stride;

    Fragment lines[8]; // Contains all values from the 8x8 tile

    Tile8x8Meta metadata;
    Tile8x8Masks indices;

    // Load/process tiles `A` and `B`
    Element fillValue = Algorithm::template outOfBoundsFillValue<Element>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    indices.a = compute_tile_indices(Tile4x4Accessor(lines, 0, 0));
    indices.b = compute_tile_indices(Tile4x4Accessor(lines, 0, 4));

    // Compute packed tiles A & B
    {
      Tile4x4Packed packed_a = pack_4x4(
          indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
      Tile4x4Packed packed_b = pack_4x4(
          indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
      writePackedT(packed, p.packed_stride, packed_a, packed_b);
    }

    // Compute/store packed tiles A & B in transpose output
    Tile4x4Packed packed_trans_a = pack_4x4(
        indices.a,
        Tile4x4Accessor(lines, 0, 0),
        metadata.meta_ac_trans,
        0,
        true);
    Tile4x4Packed packed_trans_b = pack_4x4(
        indices.b,
        Tile4x4Accessor(lines, 0, 4),
        metadata.meta_bd_trans,
        0,
        true);
    // (NOTE) Now we no longer need A & B (`lines[0:4]`)

    // Load/process tiles `C` and `D`
    CUTLASS_PRAGMA_UNROLL
    for (int i = 4; i < 8; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    indices.c = compute_tile_indices(Tile4x4Accessor(lines, 4, 0));
    indices.d = compute_tile_indices(Tile4x4Accessor(lines, 4, 4));

    // Compute packed tiles C & D
    {
      Tile4x4Packed packed_c = pack_4x4(
          indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
      Tile4x4Packed packed_d = pack_4x4(
          indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
      writePackedT(
          packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
    }

    // Compute/store packed tiles C & D in transpose output
    Tile4x4Packed packed_trans_c = pack_4x4(
        indices.c,
        Tile4x4Accessor(lines, 4, 0),
        metadata.meta_ac_trans,
        4,
        true);
    Tile4x4Packed packed_trans_d = pack_4x4(
        indices.d,
        Tile4x4Accessor(lines, 4, 4),
        metadata.meta_bd_trans,
        4,
        true);

    // Dump the metadata in a nice format
    *p.getCurrentThreadIndices() = indices;

    // Store packed A, B, C & D for transposed matrix
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
    packed_trans += 4 * p.packed_trans_stride;
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);

    // Writing meta non-transposed
    {
      ElementInputE* packed_meta_reordered = metadata_gmem.get_metaN(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(packed_meta_reordered, metadata.meta_ab);
      warp_shuffle_and_write_meta(packed_meta_reordered + 32, metadata.meta_cd);
    }

    // Writing meta transposed
    {
      ElementInputE* packed_trans_meta_reordered = metadata_gmem.get_metaT(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered, metadata.meta_ac_trans, true);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered + 32, metadata.meta_bd_trans, true);
    }
  }

  CUTLASS_DEVICE static void sparse_semi_structured_apply_kernel(Params p) {
    // See `sparse24_sparsify_both_ways_kernel`
    // It's basically the same, except that we skip
    // the part where we compute the indices we keep

    // Top-left of the 8x8 tile we own
    int warp_x = blockIdx.x * kWarpX;
    int warp_y = blockIdx.y * kWarpY;
    int x = warp_x + threadIdx.x * kThreadX;
    int y = warp_y + threadIdx.y * kThreadY;

    Element const* input = p.input + x * p.input_s0 + y;
    Element* packed = p.packed + x * p.packed_stride + (y / 2);
    Element* packed_trans =
        p.packed_trans + (x / 2) + y * p.packed_trans_stride;

    Fragment lines[8]; // Contains all values from the 8x8 tile

    Tile8x8Meta metadata;
    Tile8x8Masks indices = *p.getCurrentThreadIndices();
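    // `indices` was written to gmem by sparse_semi_structured_tile_kernel
    // above, so we re-apply the sparsity pattern selected there.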

    // Load all values from the 8x8 tile (A, B, C and D)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      // NB: values outside bounds are undefined, but shouldn't
      // be used anywhere
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }

    // Compute packed tiles A & B
    {
      Tile4x4Packed packed_a = pack_4x4(
          indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
      Tile4x4Packed packed_b = pack_4x4(
          indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
      writePackedT(packed, p.packed_stride, packed_a, packed_b);
    }

    // Compute/store packed tiles A & B in transpose output
    Tile4x4Packed packed_trans_a = pack_4x4(
        indices.a,
        Tile4x4Accessor(lines, 0, 0),
        metadata.meta_ac_trans,
        0,
        true);
    Tile4x4Packed packed_trans_b = pack_4x4(
        indices.b,
        Tile4x4Accessor(lines, 0, 4),
        metadata.meta_bd_trans,
        0,
        true);
    // (NOTE) Now we no longer need A & B (`lines[0:4]`)

    // Compute packed tiles C & D
    {
      Tile4x4Packed packed_c = pack_4x4(
          indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
      Tile4x4Packed packed_d = pack_4x4(
          indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
      writePackedT(
          packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
    }

    // Compute/store packed tiles C & D in transpose output
    Tile4x4Packed packed_trans_c = pack_4x4(
        indices.c,
        Tile4x4Accessor(lines, 4, 0),
        metadata.meta_ac_trans,
        4,
        true);
    Tile4x4Packed packed_trans_d = pack_4x4(
        indices.d,
        Tile4x4Accessor(lines, 4, 4),
        metadata.meta_bd_trans,
        4,
        true);

    // Store packed A, B, C & D for transposed matrix
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
    packed_trans += 4 * p.packed_trans_stride;
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);
  }
};
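
// A hypothetical usage sketch (the concrete `Algorithm` / `MetadataStore`
// types and the __global__ wrappers live in the .cu files that include this
// header):
//
//   using KT = KernelTypes<cutlass::half_t>;
//   template <typename MetadataStore, typename Algorithm>
//   __global__ void tile_kernel(
//       KT::Params p, MetadataStore store, Algorithm algo) {
//     KT::sparse_semi_structured_tile_kernel(p, store, algo);
//   }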

} // namespace at::native