#pragma once

#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
#include <ATen/ops/resize_as_sparse_native.h>
#endif

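// The AT_DISPATCH_* macros below dispatch on a sparse compressed layout
// (kSparseCsr, kSparseCsc, kSparseBsr, kSparseBsc) and raise an error for
// any other layout. Each action argument is a callable invoked with no
// arguments; its result becomes the result of the macro expression.
//
// Illustrative usage (hypothetical call site):
//
//   auto nnz = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
//       self.layout(), "my_op", [&] { return self._nnz(); });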
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&] {                                                              \
    const auto& the_layout = LAYOUT;                                 \
    switch (the_layout) {                                            \
      case kSparseCsr:                                               \
      case kSparseCsc:                                               \
      case kSparseBsr:                                               \
      case kSparseBsc:                                               \
        return __VA_ARGS__();                                        \
      default:                                                       \
        AT_ERROR(                                                    \
            NAME,                                                    \
            " expected sparse compressed tensor layout but got ",    \
            the_layout);                                             \
    }                                                                \
  }()

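// Dispatches between row-compressed (CSR, BSR) and column-compressed
// (CSC, BSC) layouts: ROW_DIM_ACTION runs for the former,
// COLUMN_DIM_ACTION for the latter.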
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(                \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION)              \
  [&]() {                                                         \
    const auto& the_layout = LAYOUT;                              \
    switch (the_layout) {                                         \
      case kSparseCsr:                                            \
      case kSparseBsr:                                            \
        return (ROW_DIM_ACTION)();                                \
      case kSparseCsc:                                            \
      case kSparseBsc:                                            \
        return (COLUMN_DIM_ACTION)();                             \
      default:                                                    \
        AT_ERROR(                                                 \
            NAME,                                                 \
            " expected sparse compressed tensor layout but got ", \
            the_layout);                                          \
    }                                                             \
  }()

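// Dispatches between non-block (CSR, CSC) and block (BSR, BSC) layouts:
// NO_BLOCK_ACTION runs for the former, BLOCK_ACTION for the latter.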
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(              \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION)                  \
  [&]() {                                                         \
    const auto& the_layout = LAYOUT;                              \
    switch (the_layout) {                                         \
      case kSparseCsr:                                            \
      case kSparseCsc:                                            \
        return (NO_BLOCK_ACTION)();                               \
      case kSparseBsr:                                            \
      case kSparseBsc:                                            \
        return (BLOCK_ACTION)();                                  \
      default:                                                    \
        AT_ERROR(                                                 \
            NAME,                                                 \
            " expected sparse compressed tensor layout but got ", \
            the_layout);                                          \
    }                                                             \
  }()

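// The four dispatchers below accept only a subset of the compressed
// layouts (row-compressed, column-compressed, non-block, or block
// respectively) and raise an error for every other layout.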
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS(                    \
    LAYOUT, NAME, ROW_DIM_ACTION)                                     \
  [&]() {                                                             \
    const auto& the_layout = LAYOUT;                                  \
    switch (the_layout) {                                             \
      case kSparseCsr:                                                \
      case kSparseBsr:                                                \
        return (ROW_DIM_ACTION)();                                    \
      default:                                                        \
        AT_ERROR(                                                     \
            NAME,                                                     \
            " expected sparse row compressed tensor layout but got ", \
            the_layout);                                              \
    }                                                                 \
  }()

#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS(                       \
    LAYOUT, NAME, COL_DIM_ACTION)                                        \
  [&]() {                                                                \
    const auto& the_layout = LAYOUT;                                     \
    switch (the_layout) {                                                \
      case kSparseCsc:                                                   \
      case kSparseBsc:                                                   \
        return (COL_DIM_ACTION)();                                       \
      default:                                                           \
        AT_ERROR(                                                        \
            NAME,                                                        \
            " expected sparse column compressed tensor layout but got ", \
            the_layout);                                                 \
    }                                                                    \
  }()

#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION)  \
  [&]() {                                                                     \
    const auto& the_layout = LAYOUT;                                          \
    switch (the_layout) {                                                     \
      case kSparseCsr:                                                        \
      case kSparseCsc:                                                        \
        return (ACTION)();                                                    \
      default:                                                                \
        AT_ERROR(                                                             \
            NAME,                                                             \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout);                                                      \
    }                                                                         \
  }()

#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() {                                                                 \
    const auto& the_layout = LAYOUT;                                      \
    switch (the_layout) {                                                 \
      case kSparseBsr:                                                    \
      case kSparseBsc:                                                    \
        return (ACTION)();                                                \
      default:                                                            \
        AT_ERROR(                                                         \
            NAME,                                                         \
            " expected sparse compressed block tensor layout but got ",   \
            the_layout);                                                  \
    }                                                                     \
  }()

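// Dispatches over all scalar types supported as sparse compressed values:
// all standard and complex types plus ComplexHalf, Half, Bool, and
// BFloat16.
//
// Illustrative usage (hypothetical call site):
//
//   AT_DISPATCH_SPARSE_VALUE_TYPES(
//       values.scalar_type(), "my_kernel", [&] {
//         // scalar_t is bound to the C++ type of the values here
//       });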
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                   \
      TYPE,                                             \
      NAME,                                             \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(      \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))

namespace at::sparse_csr {

// Implements RAII object to manage checking sparse tensor invariants:
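//
// Illustrative usage (hypothetical call site):
//
//   {
//     CheckSparseTensorInvariants guard(/*state=*/true);
//     // invariant checking is enabled while `guard` is alive
//   }  // the previous global setting is restored here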
class CheckSparseTensorInvariants {
  bool old_state;

 public:
  CheckSparseTensorInvariants(bool state) {
    old_state = at::globalContext().checkSparseTensorInvariants();
    at::globalContext().setCheckSparseTensorInvariants(state);
  }

  ~CheckSparseTensorInvariants() {
    at::globalContext().setCheckSparseTensorInvariants(old_state);
  }
};

using SparseCsrTensor = Tensor;

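// Returns true if the layout (or the tensor's layout, for the Tensor
// overload) is one of the sparse compressed layouts: CSR, CSC, BSR, or BSC.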
inline bool is_sparse_compressed(const Layout& layout) {
  switch (layout) {
    case kSparseCsr:
    case kSparseCsc:
    case kSparseBsr:
    case kSparseBsc:
      return true;
    default:;
  }
  return false;
}

inline bool is_sparse_compressed(const Tensor& self) {
  return is_sparse_compressed(self.layout());
}

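// Returns the underlying SparseCsrTensorImpl after checking that self has a
// sparse compressed layout (the dispatch macro errors out otherwise).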
inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(), "get_sparse_csr_impl", [&] {});
  return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
}

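// Renders a sparse compressed layout as a short string, e.g.
// layoutToString(kSparseCsr) == "Csr", with upper=true -> "CSR" and
// lower=true -> "csr".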
inline std::string layoutToString(
    Layout layout,
    bool upper = false,
    bool lower = false) {
  switch (layout) {
    case kSparseCsr:
      return (upper ? "CSR" : (lower ? "csr" : "Csr"));
    case kSparseCsc:
      return (upper ? "CSC" : (lower ? "csc" : "Csc"));
    case kSparseBsr:
      return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
    case kSparseBsc:
      return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

inline bool isCompressedRow(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
}

inline bool isCompressedColumn(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "isCompressedColumn",
      [&] { return false; },
      [&] { return true; });
}

inline std::string compressedIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "compressedIndicesName",
      [&] { return "crow_indices"; },
      [&] { return "ccol_indices"; });
}

inline std::string plainIndicesName(Layout layout) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      layout,
      "plainIndicesName",
      [&] { return "col_indices"; },
      [&] { return "row_indices"; });
}

inline std::string compressedDimName(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return "row";
    case kSparseCsc:
      return "column";
    case kSparseBsr:
      return "row block";
    case kSparseBsc:
      return "column block";
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

inline std::string plainDimName(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return "column";
    case kSparseCsc:
      return "row";
    case kSparseBsr:
      return "column block";
    case kSparseBsc:
      return "row block";
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return "";
  }
}

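// The dimension helpers below map a layout plus a tensor shape to the index
// of a particular dimension. For example (illustrative), a CSR tensor with
// sizes (B, M, N) has rowDimension == 1 and columnDimension == 2, while for
// a CSC tensor of the same shape the compressed dimension is the column, so
// compressedDimension(kSparseCsc, {B, M, N}) == 2.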
inline size_t rowDimension(Layout layout, IntArrayRef size) {
  return size.size() - (isCompressedRow(layout) ? 2 : 1);
}

inline size_t columnDimension(Layout layout, IntArrayRef size) {
  return size.size() - (isCompressedColumn(layout) ? 2 : 1);
}

inline size_t compressedDimension(
    Layout layout,
    IntArrayRef size,
    size_t dense_ndim = 0) {
  return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
}

inline size_t plainDimension(
    Layout layout,
    IntArrayRef size,
    size_t dense_ndim = 0) {
  return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
}

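// The number of batch dimensions equals the dimensionality of the
// compressed indices minus one; e.g. (illustrative) crow_indices of shape
// (B0, B1, nrows + 1) implies two batch dimensions.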
inline int64_t numBatchDimensions(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "numBatchDimensions",
      [&self] { return self.crow_indices().dim() - 1; },
      [&self] { return self.ccol_indices().dim() - 1; });
}

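// Returns the pair (compressed indices, plain indices), i.e.
// (crow_indices, col_indices) for CSR/BSR and (ccol_indices, row_indices)
// for CSC/BSC.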
inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
  return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
      self.layout(),
      "getCompressedPlainIndices",
      [&self] {
        return std::make_pair(self.crow_indices(), self.col_indices());
      },
      [&self] {
        return std::make_pair(self.ccol_indices(), self.row_indices());
      });
}

inline ScalarType getIndexDtype(Tensor const& self) {
  switch (self.layout()) {
    case kSparseCsr:
    case kSparseBsr:
      return self.crow_indices().scalar_type();
    case kSparseCsc:
    case kSparseBsc:
      return self.ccol_indices().scalar_type();
    case kSparse:
      return self._indices().scalar_type();
    default:
      return ScalarType::Long;
  }
}

inline Layout flip_compressed_layout(Layout layout) {
  switch (layout) {
    case kSparseCsr:
      return kSparseCsc;
    case kSparseCsc:
      return kSparseCsr;
    case kSparseBsr:
      return kSparseBsc;
    case kSparseBsc:
      return kSparseBsr;
    default:
      TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
      return kSparseCsr;
  }
}

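// For block layouts the values have shape
// (*batchsize, nnz, blocksize[0], blocksize[1], *densesize), so the block
// size is read from the two dimensions following the batch and nnz
// dimensions.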
inline DimVector getBlockSize(Tensor const& self) {
  int64_t n_batch = numBatchDimensions(self);
  return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
}

inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
  if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
    int64_t n_batch = numBatchDimensions(self);
    return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
  } else {
    return {};
  }
}

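// Handles the aliasing fast paths of a binary operation on sparse
// compressed tensors: when self, other, and out all alias, the op is
// applied in-place to the values; when self and other alias, the result
// reuses self's indices and only the values are computed. Returns true if
// one of these trivial cases applied, false if the caller must perform the
// general computation.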
template <typename binary_op_t, typename binary_op_out_t>
inline bool only_sparse_compressed_binary_op_trivial_cases(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha,
    Tensor& out,
    const binary_op_t& binary_op,
    const binary_op_out_t& binary_op_out) {
  // Only sparse compressed! Just like the name says :)
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));

  // Bypass BLAS if there are matches in (self, other, out)
  if (self.is_same(out) && self.is_same(other)) {
    binary_op_out(self.values(), other.values(), alpha);
    return true;
  }
  if (self.is_same(other)) {
    auto [compressed_indices, plain_indices] =
        at::sparse_csr::getCompressedPlainIndices(self);
    static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
        ->set_member_tensors(
            compressed_indices,
            plain_indices,
            binary_op(self.values(), other.values(), alpha),
            self.sizes());
    return true;
  }
  return false;
}

inline bool only_sparse_compressed_add_trivial_cases(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha,
    Tensor& out) {
  return only_sparse_compressed_binary_op_trivial_cases(
      self,
      other,
      alpha,
      out,
      [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
        return v1.add(v2, alpha);
      },
      [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
        return v1.add_(v2, alpha);
      });
}

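// Returns a copy of input with its values cast to dtype; the index tensors
// are reused as-is.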
inline Tensor to_type(const Tensor& input, ScalarType dtype) {
  auto [compressed_indices, plain_indices] =
      at::sparse_csr::getCompressedPlainIndices(input);
  return at::_sparse_compressed_tensor_unsafe(
      compressed_indices,
      plain_indices,
      std::move(input.values()).to(dtype),
      input.sizes(),
      dtype,
      input.layout(),
      input.device(),
      input.options().pinned_memory_opt());
}

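// Allocates the output values buffer together with an accumulation buffer.
// When acc_t differs from scalar_t, values are accumulated into a buffer of
// the accumulator dtype; for integral types that buffer also serves as the
// output, otherwise a separate output buffer of the original dtype is
// allocated. When the types match, the two returned tensors alias each
// other. Use copy_from_acc_buffer below to write the accumulated result
// back when the buffers are distinct.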
template <typename acc_t, typename scalar_t>
inline std::tuple<Tensor, Tensor> create_acc_buffer(
    TensorOptions option,
    ScalarType type,
    int64_t nnz = -1) {
  Tensor new_values, new_values_acc;
  constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
  bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
  if constexpr (need_acc) {
    auto acc_dtype = CppTypeToScalarType<acc_t>::value;
    new_values_acc = at::empty({}, option.dtype(acc_dtype));
    new_values = is_integral ? new_values_acc : at::empty({}, option);
  } else {
    new_values = new_values_acc = at::empty({}, option);
  }
  if (nnz != -1) {
    return std::make_tuple(
        new_values.resize_(nnz), new_values_acc.resize_(nnz));
  } else {
    return std::make_tuple(new_values, new_values_acc);
  }
}

inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
  if (!new_values_acc.is_same(new_values)) {
    new_values.copy_(new_values_acc);
  }
}

} // namespace at::sparse_csr