#pragma once

#include <ATen/native/sparse/cuda/StaticSort.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/bfloat16.h>
#include <cutlass/fast_math.h>
#include <cutlass/half.h>
#include <cutlass/integer_subbyte.h>

namespace at::native {

using cutlass::uint1b_t;
using cutlass::uint2b_t;
using cutlass::uint4b_t;
using uint8b_t = cutlass::integer_subbyte<8, false>;
using ReorderedLayoutInputE = cutlass::layout::ColumnMajorInterleaved<2>;
using ElementInputE = uint16_t;
constexpr int kWarpX = 32;
constexpr int kWarpY = 64;
constexpr int kThreadX = 8;
constexpr int kThreadY = 8;

// Bitmask of selected values, in col-major storage,
// e.g.: indices & (1 << (col + 4 * row))
using Indices4x4 = uint16_t;

struct Tile8x8Masks {
  Indices4x4 a, b, c, d;
  CUTLASS_DEVICE Tile8x8Masks() {
    a = b = c = d = 0;
  }
};

static_assert(sizeof(Tile8x8Masks) == 8, "should be exactly uint64_t");
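// Illustration only (not used by the kernels below): a minimal sketch of the
// Indices4x4 bit convention described above. The helper name is hypothetical.
CUTLASS_HOST_DEVICE constexpr bool isSelected4x4(Indices4x4 indices, int row, int col) {
  // Bit (col + 4 * row) is set iff the value at (row, col) of the 4x4 tile is
  // kept. For a 2:4 pattern, each group of 4 bits (one row) has two bits set.
  return (indices >> (col + 4 * row)) & 1;
}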

// Each thread has data for an 8x8 area of the input tensor.
// Due to the very specific format of the metadata, 32 consecutive bits
// of the metadata tensor will live in 4 different threads.
// This function does the required warp shuffling to send data to the
// right threads.
// This took some time to write (and get right), hopefully these slides
// can help
// https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g249eb2e2f2e_0_28
CUTLASS_DEVICE uint32_t
warp_shuffle_meta(uint32_t meta_ab, bool transposed = false) {
  // The required format is
  // (one line = 32 bits)
  // a[ 0, 0:16] a[ 8, 0:16] <- T0 [left]
  // a[ 0, 16:32] a[ 8, 16:32]
  // a[16, 0:16] a[24, 0:16]
  // a[16, 16:32] a[24, 16:32]
  // a[ 1, 0:16] a[ 9, 0:16] <- T4
  // a[ 1, 16:32] a[ 9, 16:32]
  // a[17, 0:16] a[25, 0:16]
  // a[17, 16:32] a[25, 16:32]
  // a[ 2, 0:16] a[10, 0:16] <- T1 [left, bottom]
  // a[ 2, 16:32] a[10, 16:32]
  // a[18, 0:16] a[26, 0:16]
  // a[18, 16:32] a[26, 16:32]
  // a[ 3, 0:16] a[11, 0:16] <- T5 [bottom]
  // a[ 3, 16:32] a[11, 16:32]
  // a[19, 0:16] a[27, 0:16]
  // a[19, 16:32] a[27, 16:32]
  // ...
  // Use warp-shuffles to send data around threads
  bool thread_left = (threadIdx.y % 2) == 0;
  bool thread_bottom = threadIdx.x % 2;

  if (transposed) {
    thread_left = (threadIdx.x % 2) == 0;
    thread_bottom = threadIdx.y % 2;
  }

  uint8b_t stage0_data[2] = {
      uint8b_t(meta_ab >> (8 * thread_left)),
      uint8b_t(meta_ab >> (8 * (thread_left + 2)))};
  // shfl t0-t4 / t1-t5
  stage0_data[0] =
      uint8b_t(__shfl_xor_sync(0xffffffff, stage0_data[0], transposed ? 1 : 4));
  stage0_data[1] =
      uint8b_t(__shfl_xor_sync(0xffffffff, stage0_data[1], transposed ? 1 : 4));

  uint16_t line0 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left))))
      << ((1 - thread_left) * 8);
  line0 |= int(stage0_data[0]) << (thread_left * 8);
  uint16_t line1 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left + 2))))
      << ((1 - thread_left) * 8);
  line1 |= int(stage0_data[1]) << (thread_left * 8);

  uint16_t stage1_data = thread_bottom ? line0 : line1;
  stage1_data = __shfl_xor_sync(0xffffffff, stage1_data, transposed ? 4 : 1);

  uint32_t final_metadata;
  if (thread_bottom) {
    final_metadata = uint32_t(stage1_data) | uint32_t(line1) << 16;
  } else {
    final_metadata = uint32_t(stage1_data) << 16 | uint32_t(line0);
  }
  return final_metadata;
}

CUTLASS_DEVICE void warp_shuffle_and_write_meta(
    ElementInputE* metadata_quad,
    uint32_t meta_ab,
    bool transposed = false) {
  bool thread_left = (threadIdx.y % 2) == 0;
  bool thread_bottom = threadIdx.x % 2;

  if (transposed) {
    thread_left = (threadIdx.x % 2) == 0;
    thread_bottom = threadIdx.y % 2;
  }

  uint32_t final_metadata = warp_shuffle_meta(meta_ab, transposed);

  int index = (!thread_left + 2 * thread_bottom) * 4;
  ((uint32_t*)metadata_quad)[index] = final_metadata;
}

template <typename Element_>
struct KernelTypes {
  using Element = Element_;
  using Fragment =
      cutlass::Array<Element, 8>; // always read from gmem in chunks of 128 bits
  using Fragment4 = cutlass::Array<Element, 4>;
  using ValuesPacked = cutlass::Array<Element, 8>; // 4 for the first col, 4 for the second col

  struct Params {
    /// inputs
    Element const* input;
    int64_t input_s0;
    int64_t input_dim0;
    int64_t input_dim1;

    /// outputs
    Element* packed;
    int64_t packed_stride;

    Element* packed_trans;
    int64_t packed_trans_stride;

    uint64_t* threads_masks;

    __host__ dim3 getBlocksGrid() const {
      return dim3(
          cutlass::ceil_div(input_dim0, kWarpX),
          cutlass::ceil_div(input_dim1, kWarpY),
          1);
    }

    static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
      return dim3(kWarpX / kThreadX, kWarpY / kThreadY, 1);
    }

    CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
      Tile8x8Masks* gmem_threads_masks = (Tile8x8Masks*)threads_masks;
      gmem_threads_masks += blockIdx.y * getThreadsGrid().y + threadIdx.y;
      int64_t strideX = gridDim.y * getThreadsGrid().y;
      gmem_threads_masks +=
          (blockIdx.x * getThreadsGrid().x + threadIdx.x) * strideX;
      return gmem_threads_masks;
    }
  };
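  // Launch-sizing sketch (illustration only; the launch site and kernel symbol
  // below are hypothetical, and the 128x192 input is an assumed example):
  //
  //   Params p = ...;            // input_dim0 = 128, input_dim1 = 192
  //   p.getBlocksGrid();         // dim3(128/32, 192/64, 1) = (4, 3, 1)
  //   Params::getThreadsGrid();  // dim3(32/8, 64/8, 1) = (4, 8, 1) -> 32 threads (one warp)
  //   some_tile_kernel<<<p.getBlocksGrid(), p.getThreadsGrid()>>>(p);
  //
  // Each block therefore covers a kWarpX x kWarpY = 32x64 slab of the input,
  // and each of its 32 threads owns one kThreadX x kThreadY = 8x8 tile.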

  struct Tile4x4Accessor {
    using Element = Element_;

    Fragment (&_lines)[8];
    int _start_row;
    int _start_col;

    CUTLASS_DEVICE Tile4x4Accessor(
        Fragment (&lines)[8],
        int start_row,
        int start_col)
        : _lines(lines), _start_row(start_row), _start_col(start_col) {}

    CUTLASS_DEVICE typename Fragment::reference at(int r, int c) {
      return _lines[r + _start_row][c + _start_col];
    }
  };

  struct Tile4x4Packed {
    Fragment4 values[2];
    CUTLASS_DEVICE Tile4x4Packed() {
      values[0].clear();
      values[1].clear();
    }
  };

  // Returns a packed 4x4 tile (i.e. 2x4 values) which corresponds to the
  // values selected in `indices`. Also fills the `meta` array in the right
  // format for consumption in the TensorCores.
  // Example:
  //  indices:  0011
  //            1001
  //            1001
  //            0100 (<- note, only 1 value on the last line)
  //  packed: values[0][2] values[1][0] values[2][0] values[3][1]
  //          values[0][3] values[1][3] values[2][3] Element(0)
  CUTLASS_DEVICE static Tile4x4Packed pack_4x4(
      Indices4x4 indices,
      Tile4x4Accessor tile,
      uint32_t& meta,
      int meta_pos,
      bool transpose = false) {
    Tile4x4Packed packed;
    CUTLASS_PRAGMA_UNROLL
    for (int row = 0; row < 4; ++row) {
      uint2b_t col0_from, col1_from;
      auto packValue = [&](uint2b_t col_to, uint2b_t col_from) {
        auto value = transpose ? tile.at(col_from, row).get()
                               : tile.at(row, col_from).get();
        packed.values[col_to][row] = value;
        if (col_to == uint2b_t(0)) {
          col0_from = col_from;
        } else {
          col1_from = col_from;
        }
      };
      auto isSelected = [&](int col) {
        if (transpose) {
          return indices & (1 << (row + 4 * col));
        }
        return indices & (1 << (col + 4 * row));
      };
      // Process cols 0/1
      // We know that col0 is always packed to position 0 if it's there,
      // and col1 is packed to pos 0 or 1 (depending on whether col0 is selected)
      if (isSelected(1)) {
        packValue(uint2b_t(0), uint2b_t(1));
      }
      if (isSelected(0)) {
        packValue(uint2b_t(0), uint2b_t(0));
      }
      if (isSelected(0) && isSelected(1)) {
        packValue(uint2b_t(1), uint2b_t(1));
      }
      // Process cols 2/3
      // same sort of heuristic
      if (isSelected(2)) {
        packValue(uint2b_t(1), uint2b_t(2));
      }
      if (isSelected(3)) {
        packValue(uint2b_t(1), uint2b_t(3));
      }
      if (isSelected(2) && isSelected(3)) {
        packValue(uint2b_t(0), uint2b_t(2));
      }
      int add_mask = (col0_from | (col1_from << 2)) << (8 * row + meta_pos);
      meta |= add_mask;
    }
    return packed;
  }
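  // Worked example (illustration only) of the metadata pack_4x4 produces for a
  // fully 2:4 pattern, assuming meta_pos == 0 and transpose == false. Per row,
  // the two kept columns become two 2-bit indices, combined as
  // (col0_from | col1_from << 2) and shifted left by 8 * row:
  //
  //   indices row 0: 0011 -> keeps cols {2, 3} -> nibble 0b1110
  //   indices row 1: 1001 -> keeps cols {0, 3} -> nibble 0b1100
  //   indices row 2: 1100 -> keeps cols {0, 1} -> nibble 0b0100
  //   indices row 3: 0110 -> keeps cols {1, 2} -> nibble 0b1001
  //
  //   meta |= 0x0E | (0x0C << 8) | (0x04 << 16) | (0x09 << 24) == 0x09040C0E
  //
  // A second 4x4 tile packed with meta_pos == 4 fills the upper nibble of each
  // of those four bytes.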

  struct Tile8x8Meta {
    // meta_ab[row] |= (real_col << (8*row + 2*pos))
    uint32_t meta_ab;
    uint32_t meta_cd;

    // meta_ac_trans[col] |= (real_row << (8*col + 2*pos))
    uint32_t meta_ac_trans;
    uint32_t meta_bd_trans;

    CUTLASS_DEVICE Tile8x8Meta() {
      meta_ab = meta_cd = meta_ac_trans = meta_bd_trans = 0;
    }
  };

  CUTLASS_DEVICE static void writePacked(
      Element* ptr,
      Fragment4 packed0,
      Fragment4 packed1) {
    Fragment write;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      write[i] = packed0[i].get();
      write[i + 4] = packed1[i].get();
    }
    cutlass::arch::global_store<Fragment, sizeof(Fragment)>(write, ptr, true);
  }

  CUTLASS_DEVICE static void writePackedT(
      Element* ptr,
      int64_t stride,
      Tile4x4Packed a,
      Tile4x4Packed b) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      Fragment4 write;
      write[0] = a.values[0][i].get();
      write[1] = a.values[1][i].get();
      write[2] = b.values[0][i].get();
      write[3] = b.values[1][i].get();
      cutlass::arch::global_store<Fragment4, sizeof(Fragment4)>(
          write, ptr + i * stride, true);
    }
  }

  template <typename Algorithm, typename MetadataStore>
  CUTLASS_DEVICE static void sparse_semi_structured_tile_kernel(
      Params p,
      MetadataStore metadata_gmem,
      Algorithm compute_tile_indices) {
    // Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
    // A, B, C and D, as displayed in the following schema:
    // +---+---+
    // | A | B |
    // +---+---+
    // | C | D |
    // +---+---+
    // Each warp (32 threads) is then responsible for a 32x64 tile of the
    // input.
    // This configuration allows reading/writing data in 128-bit chunks. These
    // memory accesses are coalesced at the warp level into 128-byte
    // transactions. See also:
    // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g2494f30c7cf_0_0

    // Top-left of the 8x8 tile we own
    int warp_x = blockIdx.x * kWarpX;
    int warp_y = blockIdx.y * kWarpY;
    int x = warp_x + threadIdx.x * kThreadX;
    int y = warp_y + threadIdx.y * kThreadY;

    Element const* input = p.input + x * p.input_s0 + y;
    Element* packed = p.packed + x * p.packed_stride + (y / 2);
    Element* packed_trans =
        p.packed_trans + (x / 2) + y * p.packed_trans_stride;

    Fragment lines[8]; // Contains all values from the 8x8 tile

    Tile8x8Meta metadata;
    Tile8x8Masks indices;

    // Load/process tiles `A` and `B`
    Element fillValue = Algorithm::template outOfBoundsFillValue<Element>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    indices.a = compute_tile_indices(Tile4x4Accessor(lines, 0, 0));
    indices.b = compute_tile_indices(Tile4x4Accessor(lines, 0, 4));

    // Compute packed tiles A & B
    {
      Tile4x4Packed packed_a = pack_4x4(
          indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
      Tile4x4Packed packed_b = pack_4x4(
          indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
      writePackedT(packed, p.packed_stride, packed_a, packed_b);
    }

    // Compute/store packed tiles A & B in transposed output
    Tile4x4Packed packed_trans_a = pack_4x4(
        indices.a,
        Tile4x4Accessor(lines, 0, 0),
        metadata.meta_ac_trans,
        0,
        true);
    Tile4x4Packed packed_trans_b = pack_4x4(
        indices.b,
        Tile4x4Accessor(lines, 0, 4),
        metadata.meta_bd_trans,
        0,
        true);
    // (NOTE) Now we no longer need A & B (`lines[0:4]`)

    // Load/process tiles `C` and `D`
    CUTLASS_PRAGMA_UNROLL
    for (int i = 4; i < 8; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    indices.c = compute_tile_indices(Tile4x4Accessor(lines, 4, 0));
    indices.d = compute_tile_indices(Tile4x4Accessor(lines, 4, 4));

    // Compute packed tiles C & D
    {
      Tile4x4Packed packed_c = pack_4x4(
          indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
      Tile4x4Packed packed_d = pack_4x4(
          indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
      writePackedT(
          packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
    }

    // Compute/store packed tiles C & D in transposed output
    Tile4x4Packed packed_trans_c = pack_4x4(
        indices.c,
        Tile4x4Accessor(lines, 4, 0),
        metadata.meta_ac_trans,
        4,
        true);
    Tile4x4Packed packed_trans_d = pack_4x4(
        indices.d,
        Tile4x4Accessor(lines, 4, 4),
        metadata.meta_bd_trans,
        4,
        true);

    // Dump the metadata in a nice format
    *p.getCurrentThreadIndices() = indices;

    // Store packed A, B, C & D for the transposed matrix
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
    packed_trans += 4 * p.packed_trans_stride;
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);

    // Writing meta non-transposed
    {
      ElementInputE* packed_meta_reordered = metadata_gmem.get_metaN(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(packed_meta_reordered, metadata.meta_ab);
      warp_shuffle_and_write_meta(packed_meta_reordered + 32, metadata.meta_cd);
    }

    // Writing meta transposed
    {
      ElementInputE* packed_trans_meta_reordered = metadata_gmem.get_metaT(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered, metadata.meta_ac_trans, true);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered + 32, metadata.meta_bd_trans, true);
    }
  }
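  // Coordinate sketch for the kernels in this struct (illustration only, using
  // hypothetical index values): for blockIdx = (1, 0) and threadIdx = (2, 3),
  //   x = 1 * kWarpX + 2 * kThreadX = 48   (row of the tile's top-left corner)
  //   y = 0 * kWarpY + 3 * kThreadY = 24   (column of the tile's top-left corner)
  // The packed output keeps 2 of every 4 values along a row, hence the `y / 2`
  // column offset (packed = p.packed + 48 * p.packed_stride + 12); the
  // transposed packed output is halved along the x dimension instead, hence
  // the `x / 2` offset.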

  CUTLASS_DEVICE static void sparse_semi_structured_apply_kernel(Params p) {
    // See `sparse24_sparsify_both_ways_kernel`
    // It's basically the same, except that we skip
    // the part where we compute the indices we keep

    // Top-left of the 8x8 tile we own
    int warp_x = blockIdx.x * kWarpX;
    int warp_y = blockIdx.y * kWarpY;
    int x = warp_x + threadIdx.x * kThreadX;
    int y = warp_y + threadIdx.y * kThreadY;

    Element const* input = p.input + x * p.input_s0 + y;
    Element* packed = p.packed + x * p.packed_stride + (y / 2);
    Element* packed_trans =
        p.packed_trans + (x / 2) + y * p.packed_trans_stride;

    Fragment lines[8]; // Contains all values from the 8x8 tile

    Tile8x8Meta metadata;
    Tile8x8Masks indices = *p.getCurrentThreadIndices();

    // Load/process tiles `A` and `B`
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 8; ++i) {
      // NB: Values outside of bounds are undefined, but shouldn't
      // be used anywhere
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }

    // Compute packed tiles A & B
    {
      Tile4x4Packed packed_a = pack_4x4(
          indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
      Tile4x4Packed packed_b = pack_4x4(
          indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
      writePackedT(packed, p.packed_stride, packed_a, packed_b);
    }

    // Compute/store packed tiles A & B in transposed output
    Tile4x4Packed packed_trans_a = pack_4x4(
        indices.a,
        Tile4x4Accessor(lines, 0, 0),
        metadata.meta_ac_trans,
        0,
        true);
    Tile4x4Packed packed_trans_b = pack_4x4(
        indices.b,
        Tile4x4Accessor(lines, 0, 4),
        metadata.meta_bd_trans,
        0,
        true);
    // (NOTE) Now we no longer need A & B (`lines[0:4]`)

    // Compute packed tiles C & D
    {
      Tile4x4Packed packed_c = pack_4x4(
          indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
      Tile4x4Packed packed_d = pack_4x4(
          indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
      writePackedT(
          packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
    }

    // Compute/store packed tiles C & D in transposed output
    Tile4x4Packed packed_trans_c = pack_4x4(
        indices.c,
        Tile4x4Accessor(lines, 4, 0),
        metadata.meta_ac_trans,
        4,
        true);
    Tile4x4Packed packed_trans_d = pack_4x4(
        indices.d,
        Tile4x4Accessor(lines, 4, 4),
        metadata.meta_bd_trans,
        4,
        true);

    // Store packed A, B, C & D for the transposed matrix
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
    packed_trans += 4 * p.packed_trans_stride;
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);
  }
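  // Informal summary of how the two kernels relate (an assumed usage pattern,
  // not mandated by this header): the tile kernel runs `compute_tile_indices`
  // on each 4x4 tile, packs the kept values, and persists the per-thread
  // Tile8x8Masks via getCurrentThreadIndices(). The apply kernel can later be
  // launched with the same grid on a tensor of identical shape: it reloads
  // those masks and re-packs the new values with the previously chosen 2:4
  // sparsity pattern, without recomputing indices.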
};

} // namespace at::native