xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/sparse/cuda/SparseSemiStructuredPack.h>
4 #include <ATen/native/sparse/cuda/StaticSort.h>
5 #include <cutlass/bfloat16.h>
6 #include <cutlass/half.h>
7 
8 // Given 4x4 values, computes the selected indices that will remain after 2:4
9 // sparsification, as a bitmask.
10 // NOTE: Algorithms might select LESS than 8 values in total in some cases.
11 
namespace platform {
// Minimal numeric_limits specialization for cutlass::bfloat16_t. Only
// `infinity()` is provided; it is what `outOfBoundsFillValue()` in this file
// uses.
template <>
struct numeric_limits<cutlass::bfloat16_t> {
  static CUTLASS_HOST_DEVICE cutlass::bfloat16_t infinity() {
    // bfloat16 positive infinity: sign bit 0, 8-bit exponent all ones (0xff),
    // 7-bit mantissa zero -> raw bits 0x7f80.
    return cutlass::bfloat16_t::bitcast(0x7f80);
  }
};
} // namespace platform
21 
22 namespace at::native{
23 
// One entry of a 4x4 tile: the value together with its (row, col) position,
// so that after sorting we still know where each value came from.
// Ordering (operator<) compares only `Pointwise::apply(value)`; the position
// fields do not take part in the comparison.
// `raw` aliases the packed entry (presumably 32 bits when Element is a 16-bit
// type such as half/bfloat16 — confirm before relying on it for wider types).
template <typename Element, typename Pointwise>
struct TileValueOrderedT {
  union {
    struct {
      Element value;
      uint2b_t col;
      uint2b_t row;
    } parts;
    uint32_t raw;
  };

  // Intentionally leaves every member uninitialized; callers fill `parts`.
  CUTLASS_DEVICE TileValueOrderedT() {}

  CUTLASS_DEVICE bool operator<(
      TileValueOrderedT<Element, Pointwise> const& other) const {
    auto const lhs_rank = Pointwise::apply(parts.value);
    auto const rhs_rank = Pointwise::apply(other.parts.value);
    return lhs_rank < rhs_rank;
  }
};
40 
// Ranking operations: each one exposes a static `apply(x)` that maps a tile
// value to the key it is ranked by.

// Ranks values by their own value (identity mapping).
struct IdentityOp {
  template <typename T>
  static CUTLASS_HOST_DEVICE T apply(T const& x) {
    return x;
  }
};
// Ranks values by magnitude (absolute value, via cutlass::abs).
struct AbsOp {
  template <typename T>
  static CUTLASS_HOST_DEVICE T apply(T const& x) {
    T const magnitude = cutlass::abs(x);
    return magnitude;
  }
};
55 
// Given 4x4 values, computes the selected indices that will remain after 2:4
// sparsification, as a bitmask (bit `col + 4 * row` set = value kept).
// We have 2 constraints:
// (1) At most 2 values per row
// (2) At most 2 values per column
// This means we can select at most 8 values in total.
// ALGO: We use a greedy algorithm, where we take the values of the 4x4 tile
// in descending order of `Op::apply(value)`. If a value fits (because its
// row/col is not already full), we select it. Then we move on to the next one.
// NOTE: This algorithm might select LESS than 8 values in total in some cases.
// NOTE (2): Registers are not indexable, so we shouldn't rely on dynamically
//   indexing values at any point, otherwise they will be placed in local
//   memory.
template <typename Op = IdentityOp>
struct LargestValuesGreedy {
  // Fill value for tile positions outside the input (presumably used by the
  // caller to pad out-of-bounds elements — confirm at the call site).
  // It must rank LAST under `Op` so that padding is never preferred over real
  // values. For IdentityOp that is -inf; under AbsOp, however, -inf ranks as
  // +inf (i.e. FIRST), so 0 must be used instead. We return whichever of
  // {-inf, 0} ranks lower under `Op` (compile-time foldable).
  template <typename T>
  static CUTLASS_DEVICE T outOfBoundsFillValue() {
    T const neg_inf = -platform::numeric_limits<T>::infinity();
    T const zero = T(0.0f);
    return Op::apply(zero) < Op::apply(neg_inf) ? zero : neg_inf;
  }

  // Returns the bitmask of values kept in the 4x4 tile `values`.
  template <typename Tile4x4Accessor>
  CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) {
    using TileValueOrdered =
        TileValueOrderedT<typename Tile4x4Accessor::Element, Op>;
    using TileValuesFragment = cutlass::Array<TileValueOrdered, 4 * 4>;

    // Tag every value with its (row, col) position before sorting.
    TileValuesFragment values_ordered;
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 4; ++j) {
        TileValueOrdered& v = values_ordered[i * 4 + j];
        v.parts.value = values.at(i, j).get();
        v.parts.col = uint2b_t(j);
        v.parts.row = uint2b_t(i);
      }
    }
    // Use a sorting network (aka without branches) to avoid
    // warp divergence
    StaticSort<TileValuesFragment::kElements> sorter;
    sorter(values_ordered);

    // Packed 2-bit counters of how many values were selected on a given
    // row/col:
    // 0 selected: (numPerRow >> 2*row) = 00 (0)
    // 1 selected: (numPerRow >> 2*row) = 01 (1)
    // 2 selected: (numPerRow >> 2*row) = 11 (3)
    uint32_t numPerRow = 0;
    uint32_t numPerCol = 0;
    Indices4x4 indices = 0;

    // Take as many as we can, starting with the largest values
    CUTLASS_PRAGMA_UNROLL
    for (int i = values_ordered.size() - 1; i >= 0; i--) {
      auto& e = values_ordered[i];

      uint32_t rcount = uint2b_t(numPerRow >> 2 * e.parts.row);
      uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col);
      // NOTE: This is more efficient than (yet equivalent to)
      // `rcount != 3 && ccount != 3`, since counts are only ever 0, 1 or 3.
      bool selected = (rcount + ccount) <= 2;
      indices |= selected << (e.parts.col + 4 * e.parts.row);

      // OR-ing (count + selected) moves a counter 00 -> 01 -> 11 (1 | 2 = 3).
      numPerRow |= (rcount + selected) << 2 * e.parts.row;
      numPerCol |= (ccount + selected) << 2 * e.parts.col;
    }
    return indices;
  }
};
122 
// We consider each row independently, in order (row 0 first).
// This ensures that a row's sparsity pattern is only determined
// by its own values and the rows before it (but never the rows after).
// This enforces causality strictly.
// Row r keeps at most kMaxValuesPerRow[r] = {1, 1, 2, 2} values and each
// column keeps at most 2, so at most 6 values are selected per 4x4 tile.
template <typename Op = IdentityOp>
struct Causal1122 {
  // Fill value for tile positions outside the input (presumably used by the
  // caller to pad out-of-bounds elements — confirm at the call site).
  // It must rank LAST under `Op` so that padding is never preferred over real
  // values. For IdentityOp that is -inf; under AbsOp, however, -inf ranks as
  // +inf (i.e. FIRST), so 0 must be used instead. We return whichever of
  // {-inf, 0} ranks lower under `Op` (compile-time foldable).
  template <typename T>
  static CUTLASS_DEVICE T outOfBoundsFillValue() {
    T const neg_inf = -platform::numeric_limits<T>::infinity();
    T const zero = T(0.0f);
    return Op::apply(zero) < Op::apply(neg_inf) ? zero : neg_inf;
  }

  // Returns the bitmask (bit `col + 4 * row` set = value kept) of values kept
  // in the 4x4 tile `values`.
  template <typename Tile4x4Accessor>
  CUTLASS_DEVICE Indices4x4 operator()(Tile4x4Accessor values) {
    static constexpr int kMaxValuesPerRow[] = {1, 1, 2, 2};
    using TileValueOrdered =
        TileValueOrderedT<typename Tile4x4Accessor::Element, Op>;
    using TileValuesFragment = cutlass::Array<TileValueOrdered, 4>;
    Indices4x4 indices = 0;

    // Packed 2-bit per-column counters of selected values:
    // 00 = none, 01 = one, 11 = two (column full).
    uint32_t numPerCol = 0;

    CUTLASS_PRAGMA_UNROLL
    for (int row = 0; row < 4; ++row) {
      int row_count = 0;
      // Only `value` and `col` are filled; the row is the loop variable.
      TileValuesFragment values_ordered;
      CUTLASS_PRAGMA_UNROLL
      for (int col = 0; col < 4; ++col) {
        TileValueOrdered& v = values_ordered[col];
        v.parts.value = values.at(row, col).get();
        v.parts.col = uint2b_t(col);
      }
      // Use a sorting network (aka without branches) to avoid
      // warp divergence
      StaticSort<TileValuesFragment::kElements> sorter;
      sorter(values_ordered);

      // Take as many as we can, starting with the largest values
      CUTLASS_PRAGMA_UNROLL
      for (int i = values_ordered.size() - 1; i >= 0; i--) {
        auto& e = values_ordered[i];

        uint32_t ccount = uint2b_t(numPerCol >> 2 * e.parts.col);
        bool selected = ccount != 3 && (row_count < kMaxValuesPerRow[row]);
        indices |= selected << (e.parts.col + 4 * row);
        // OR-ing (count + selected) moves a counter 00 -> 01 -> 11.
        numPerCol |= (ccount + selected) << 2 * e.parts.col;
        row_count += selected;
      }
    }
    return indices;
  }
};
174 
// Invokes `callback(algorithm, name)` once per registered tile-selection
// algorithm, in a fixed order. The final entry, with an empty name, is the
// default algorithm (largest values, greedy).
template <typename T>
void named_algorithms(T callback) {
  using GreedyByValue = LargestValuesGreedy<IdentityOp>;
  using GreedyByMagnitude = LargestValuesGreedy<AbsOp>;
  using CausalByValue = Causal1122<IdentityOp>;

  callback(GreedyByValue(), "largest_values_greedy");
  callback(CausalByValue(), "causal1122");
  callback(GreedyByMagnitude(), "largest_abs_values_greedy");
  // Empty name -> default algorithm.
  callback(GreedyByValue(), "");
}
183 
184 } // namespace
185