1 #pragma once 2 3 #include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h> 4 #include <ATen/native/sparse/cuda/StaticSort.h> 5 #include <cutlass/bfloat16.h> 6 #include <cutlass/half.h> 7 8 // Given 4x4 values, computes the selected indices that will remain after 2:4 9 // sparsification, as a bitmask. 10 // NOTE: Algorithms might select LESS than 8 values in total in some cases. 11 12 namespace platform { 13 template <> 14 struct numeric_limits<cutlass::bfloat16_t> { 15 CUTLASS_HOST_DEVICE 16 static cutlass::bfloat16_t infinity() { 17 return cutlass::bfloat16_t::bitcast(0x7f80); 18 } 19 }; 20 } // namespace platform 21 22 namespace at::native{ 23 24 template <typename Element, typename Pointwise> 25 struct TileValueOrderedT { 26 union { 27 struct { 28 Element value; 29 uint2b_t col; 30 uint2b_t row; 31 } parts; 32 uint32_t raw; 33 }; 34 CUTLASS_DEVICE bool operator<( 35 TileValueOrderedT<Element, Pointwise> const& other) const { 36 return Pointwise::apply(parts.value) < Pointwise::apply(other.parts.value); 37 } 38 CUTLASS_DEVICE TileValueOrderedT() {} 39 }; 40 41 // Operations that we can apply to rank the values 42 struct IdentityOp { 43 template <typename T> 44 static T CUTLASS_HOST_DEVICE apply(T const& x) { 45 return x; 46 } 47 }; 48 // Can be applied to rank based on absolute value 49 struct AbsOp { 50 template <typename T> 51 static T CUTLASS_HOST_DEVICE apply(T const& x) { 52 return cutlass::abs(x); 53 } 54 }; 55 56 // Given 4x4 values, computes the selected indices that will remain after 2:4 57 // sparsification, as a bitmask. We have 2 constraints: 58 // (1) At most 2 values per line 59 // (2) At most 2 values per column 60 // This means we can select at most 8 values in total. 61 // ALGO: We use a greedy algorithm, where we take values in the 4x4 62 // tile in descending order. If a value fits (because the line/col is not 63 // already full), we select it. Then we move on to the next one. 64 // NOTE: This algorithm might select LESS than 8 values in total in some cases. 65 // NOTE (2): RF are not indexable, so we shouldn't rely on indexing 66 // values at any point, otherwise they will be stored in local memory. 67 template <typename Op = IdentityOp> 68 struct LargestValuesGreedy { 69 template <typename T> 70 static CUTLASS_DEVICE T outOfBoundsFillValue() { 71 return -platform::numeric_limits<T>::infinity(); 72 } 73 74 template <typename Tile4x4Accessor> 75 CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) { 76 using TileValueOrdered = 77 TileValueOrderedT<typename Tile4x4Accessor::Element, Op>; 78 using TileValuesFragment = cutlass::Array<TileValueOrdered, 4 * 4>; 79 Indices4x4 indices; 80 TileValuesFragment values_ordered; 81 CUTLASS_PRAGMA_UNROLL 82 for (int i = 0; i < 4; ++i) { 83 CUTLASS_PRAGMA_UNROLL 84 for (int j = 0; j < 4; ++j) { 85 TileValueOrdered& v = values_ordered[i * 4 + j]; 86 v.parts.value = values.at(i, j).get(); 87 v.parts.col = uint2b_t(j); 88 v.parts.row = uint2b_t(i); 89 } 90 } 91 // Use a sorting network (aka without branches) to avoid 92 // warp divergence 93 StaticSort<TileValuesFragment::kElements> sorter; 94 sorter(values_ordered); 95 96 // bitmask to store how many we have selected on a given row/col 97 // 0 selected: (numPerRow >> 2*row) = 00 (0) 98 // 1 selected: (numPerRow >> 2*row) = 01 (1) 99 // 2 selected: (numPerRow >> 2*row) = 11 (3) 100 uint32_t numPerRow = 0; 101 uint32_t numPerCol = 0; 102 indices = 0; 103 104 // Take as many as we can, starting with the largest values 105 CUTLASS_PRAGMA_UNROLL 106 for (int i = values_ordered.size() - 1; i >= 0; i--) { 107 auto& e = values_ordered[i]; 108 109 uint32_t rcount = uint2b_t(numPerRow >> 2 * e.parts.row); 110 uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col); 111 // NOTE: This is more efficient (yet equivalent) to: 112 // `rcount != 3 && ccount != 3` 113 bool selected = (rcount + ccount) <= 2; 114 indices |= selected << (e.parts.col + 4 * e.parts.row); 115 116 numPerRow |= (rcount + selected) << 2 * e.parts.row; 117 numPerCol |= (ccount + selected) << 2 * e.parts.col; 118 } 119 return indices; 120 } 121 }; 122 123 // We consider each rows independantly in order 124 // This is to ensure that a row's sparsity pattern is only determined 125 // by its values and the rows before (but never the rows after) 126 // This enforces causality strictly 127 template <typename Op = IdentityOp> 128 struct Causal1122 { 129 template <typename T> 130 static CUTLASS_DEVICE T outOfBoundsFillValue() { 131 return -platform::numeric_limits<T>::infinity(); 132 } 133 134 template <typename Tile4x4Accessor> 135 CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) { 136 static constexpr int kMaxValuesPerRow[] = {1, 1, 2, 2}; 137 using TileValueOrdered = 138 TileValueOrderedT<typename Tile4x4Accessor::Element, Op>; 139 using TileValuesFragment = cutlass::Array<TileValueOrdered, 4>; 140 Indices4x4 indices = 0; 141 142 uint32_t numPerCol = 0; // <- see doc in `LargestValuesGreedy` 143 144 CUTLASS_PRAGMA_UNROLL 145 for (int row = 0; row < 4; ++row) { 146 int row_count = 0; 147 TileValuesFragment values_ordered; 148 CUTLASS_PRAGMA_UNROLL 149 for (int col = 0; col < 4; ++col) { 150 TileValueOrdered& v = values_ordered[col]; 151 v.parts.value = values.at(row, col).get(); 152 v.parts.col = uint2b_t(col); 153 } 154 // Use a sorting network (aka without branches) to avoid 155 // warp divergence 156 StaticSort<TileValuesFragment::kElements> sorter; 157 sorter(values_ordered); 158 159 // Take as many as we can, starting with the largest values 160 CUTLASS_PRAGMA_UNROLL 161 for (int i = values_ordered.size() - 1; i >= 0; i--) { 162 auto& e = values_ordered[i]; 163 164 uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col); 165 bool selected = ccount != 3 && (row_count < kMaxValuesPerRow[row]); 166 indices |= selected << (e.parts.col + 4 * row); 167 numPerCol |= (ccount + selected) << 2 * e.parts.col; 168 row_count += selected; 169 } 170 } 171 return indices; 172 } 173 }; 174 175 template <typename T> 176 void named_algorithms(T callback) { 177 callback(LargestValuesGreedy<IdentityOp>(), "largest_values_greedy"); 178 callback(Causal1122<IdentityOp>(), "causal1122"); 179 callback(LargestValuesGreedy<AbsOp>(), "largest_abs_values_greedy"); 180 // default one 181 callback(LargestValuesGreedy<IdentityOp>(), ""); 182 } 183 184 } // namespace 185