1 #pragma once
2
3 #include <ATen/SparseCsrTensorImpl.h>
4 #include <ATen/SparseTensorImpl.h>
5 #include <ATen/core/Tensor.h>
6
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #include <ATen/NativeFunctions.h>
10 #include <ATen/Operators.h>
11 #else
12 #include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
13 #include <ATen/ops/resize_as_sparse_native.h>
14 #endif
15
// Evaluates to the result of calling the trailing callable argument with no
// arguments when LAYOUT is one of kSparseCsr, kSparseCsc, kSparseBsr or
// kSparseBsc. For any other layout an error is raised whose message starts
// with NAME.
// The callable is spelled through __VA_ARGS__, so it may contain top-level
// commas. LAYOUT is evaluated exactly once.
// The failure path uses TORCH_CHECK(false, ...) (as the rest of this file
// does) instead of the deprecated AT_ERROR alias; the message is unchanged.
#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
  [&] { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
      case kSparseBsr: \
      case kSparseBsc: \
        return __VA_ARGS__(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
32
// Two-way dispatch on LAYOUT over the sparse compressed layouts:
//   kSparseCsr, kSparseBsr -> result of calling ROW_DIM_ACTION with no args
//   kSparseCsc, kSparseBsc -> result of calling COLUMN_DIM_ACTION with no args
// Any other layout raises an error whose message starts with NAME.
// Both actions must return the same type, because they are returned from a
// single lambda with a deduced return type. LAYOUT is evaluated exactly once.
// The failure path uses TORCH_CHECK(false, ...) (as the rest of this file
// does) instead of the deprecated AT_ERROR alias; the message is unchanged.
#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COLUMN_DIM_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
51
// Two-way dispatch on LAYOUT over the sparse compressed layouts:
//   kSparseCsr, kSparseCsc -> result of calling NO_BLOCK_ACTION with no args
//   kSparseBsr, kSparseBsc -> result of calling BLOCK_ACTION with no args
// Any other layout raises an error whose message starts with NAME.
// Both actions must return the same type, because they are returned from a
// single lambda with a deduced return type. LAYOUT is evaluated exactly once.
// The failure path uses TORCH_CHECK(false, ...) (as the rest of this file
// does) instead of the deprecated AT_ERROR alias; the message is unchanged.
#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (NO_BLOCK_ACTION)(); \
      case kSparseBsr: \
      case kSparseBsc: \
        return (BLOCK_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
70
// Evaluates to the result of calling ROW_DIM_ACTION with no arguments when
// LAYOUT is kSparseCsr or kSparseBsr. Every other layout (including the
// column-compressed ones) raises an error whose message starts with NAME.
// LAYOUT is evaluated exactly once.
// The failure path uses TORCH_CHECK(false, ...) (as the rest of this file
// does) instead of the deprecated AT_ERROR alias; the message is unchanged.
#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, ROW_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseBsr: \
        return (ROW_DIM_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse row compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
86
// Evaluates to the result of calling COL_DIM_ACTION with no arguments when
// LAYOUT is kSparseCsc or kSparseBsc. Every other layout (including the
// row-compressed ones) raises an error whose message starts with NAME.
// LAYOUT is evaluated exactly once.
// The failure path uses TORCH_CHECK(false, ...) (as the rest of this file
// does) instead of the deprecated AT_ERROR alias; the message is unchanged.
#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
    LAYOUT, NAME, COL_DIM_ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsc: \
      case kSparseBsc: \
        return (COL_DIM_ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse column compressed tensor layout but got ", \
            the_layout); \
    } \
  }()
102
// Evaluates to the result of calling ACTION with no arguments when LAYOUT is
// kSparseCsr or kSparseCsc. Every other layout (including the block ones)
// raises an error whose message starts with NAME.
// LAYOUT is evaluated exactly once.
// The failure path uses TORCH_CHECK(false, ...) (as the rest of this file
// does) instead of the deprecated AT_ERROR alias; the message is unchanged.
#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseCsr: \
      case kSparseCsc: \
        return (ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed (non-block) tensor layout but got ", \
            the_layout); \
    } \
  }()
117
// Evaluates to the result of calling ACTION with no arguments when LAYOUT is
// kSparseBsr or kSparseBsc. Every other layout (including the non-block
// ones) raises an error whose message starts with NAME.
// LAYOUT is evaluated exactly once.
// The failure path uses TORCH_CHECK(false, ...) (as the rest of this file
// does) instead of the deprecated AT_ERROR alias; the message is unchanged.
#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
  [&]() { \
    const auto& the_layout = LAYOUT; \
    switch (the_layout) { \
      case kSparseBsr: \
      case kSparseBsc: \
        return (ACTION)(); \
      default: \
        TORCH_CHECK( \
            false, \
            NAME, \
            " expected sparse compressed block tensor layout but got ", \
            the_layout); \
    } \
  }()
132
// Scalar-type dispatch for sparse tensor values. TYPE and NAME are forwarded
// to AT_DISPATCH_SWITCH; the case list is
// AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4 extended with kComplexHalf,
// kHalf, kBool and kBFloat16, and __VA_ARGS__ is forwarded as its final
// argument (the code to run for the matched type).
#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH( \
      TYPE, \
      NAME, \
      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
          kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
139
140 namespace at::sparse_csr {
141
142 // Implements RAII object to manage checking sparse tensor invariants:
143 class CheckSparseTensorInvariants {
144 bool old_state;
145
146 public:
CheckSparseTensorInvariants(bool state)147 CheckSparseTensorInvariants(bool state) {
148 old_state = at::globalContext().checkSparseTensorInvariants();
149 at::globalContext().setCheckSparseTensorInvariants(state);
150 }
151
~CheckSparseTensorInvariants()152 ~CheckSparseTensorInvariants() {
153 at::globalContext().setCheckSparseTensorInvariants(old_state);
154 }
155 };
156
// Plain alias of Tensor (no distinct type is introduced); used in this file
// as the parameter type of get_sparse_csr_impl.
using SparseCsrTensor = Tensor;
158
is_sparse_compressed(const Layout & layout)159 inline bool is_sparse_compressed(const Layout& layout) {
160 switch (layout) {
161 case kSparseCsr:
162 case kSparseCsc:
163 case kSparseBsr:
164 case kSparseBsc:
165 return true;
166 default:;
167 }
168 return false;
169 }
170
is_sparse_compressed(const Tensor & self)171 inline bool is_sparse_compressed(const Tensor& self) {
172 return is_sparse_compressed(self.layout());
173 }
174
get_sparse_csr_impl(const SparseCsrTensor & self)175 inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
176 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
177 self.layout(), "get_sparse_csr_impl", [&] {});
178 return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
179 }
180
181 inline std::string layoutToString(
182 Layout layout,
183 bool upper = false,
184 bool lower = false) {
185 switch (layout) {
186 case kSparseCsr:
187 return (upper ? "CSR" : (lower ? "csr" : "Csr"));
188 case kSparseCsc:
189 return (upper ? "CSC" : (lower ? "csc" : "Csc"));
190 case kSparseBsr:
191 return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
192 case kSparseBsc:
193 return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
194 default:
195 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
196 return "";
197 }
198 }
199
isCompressedRow(Layout layout)200 inline bool isCompressedRow(Layout layout) {
201 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
202 layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
203 }
204
isCompressedColumn(Layout layout)205 inline bool isCompressedColumn(Layout layout) {
206 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
207 layout,
208 "isCompressedColumn",
209 [&] { return false; },
210 [&] { return true; });
211 }
212
compressedIndicesName(Layout layout)213 inline std::string compressedIndicesName(Layout layout) {
214 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
215 layout,
216 "compressedIndicesName",
217 [&] { return "crow_indices"; },
218 [&] { return "ccol_indices"; });
219 }
220
plainIndicesName(Layout layout)221 inline std::string plainIndicesName(Layout layout) {
222 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
223 layout,
224 "plainIndicesName",
225 [&] { return "col_indices"; },
226 [&] { return "row_indices"; });
227 }
228
compressedDimName(Layout layout)229 inline std::string compressedDimName(Layout layout) {
230 switch (layout) {
231 case kSparseCsr:
232 return "row";
233 case kSparseCsc:
234 return "column";
235 case kSparseBsr:
236 return "row block";
237 case kSparseBsc:
238 return "column block";
239 default:
240 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
241 return "";
242 }
243 }
244
plainDimName(Layout layout)245 inline std::string plainDimName(Layout layout) {
246 switch (layout) {
247 case kSparseCsr:
248 return "column";
249 case kSparseCsc:
250 return "row";
251 case kSparseBsr:
252 return "column block";
253 case kSparseBsc:
254 return "row block";
255 default:
256 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
257 return "";
258 }
259 }
260
rowDimension(Layout layout,IntArrayRef size)261 inline size_t rowDimension(Layout layout, IntArrayRef size) {
262 return size.size() - (isCompressedRow(layout) ? 2 : 1);
263 }
264
columnDimension(Layout layout,IntArrayRef size)265 inline size_t columnDimension(Layout layout, IntArrayRef size) {
266 return size.size() - (isCompressedColumn(layout) ? 2 : 1);
267 }
268
269 inline size_t compressedDimension(
270 Layout layout,
271 IntArrayRef size,
272 size_t dense_ndim = 0) {
273 return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
274 }
275
276 inline size_t plainDimension(
277 Layout layout,
278 IntArrayRef size,
279 size_t dense_ndim = 0) {
280 return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
281 }
282
numBatchDimensions(Tensor const & self)283 inline int64_t numBatchDimensions(Tensor const& self) {
284 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
285 self.layout(),
286 "numBatchDimensions",
287 [&self] { return self.crow_indices().dim() - 1; },
288 [&self] { return self.ccol_indices().dim() - 1; });
289 }
290
getCompressedPlainIndices(Tensor const & self)291 inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
292 return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
293 self.layout(),
294 "getCompressedPlainIndices",
295 [&self] {
296 return std::make_pair(self.crow_indices(), self.col_indices());
297 },
298 [&self] {
299 return std::make_pair(self.ccol_indices(), self.row_indices());
300 });
301 }
302
getIndexDtype(Tensor const & self)303 inline ScalarType getIndexDtype(Tensor const& self) {
304 switch (self.layout()) {
305 case kSparseCsr:
306 case kSparseBsr:
307 return self.crow_indices().scalar_type();
308 case kSparseCsc:
309 case kSparseBsc:
310 return self.ccol_indices().scalar_type();
311 case kSparse:
312 return self._indices().scalar_type();
313 default:
314 return ScalarType::Long;
315 }
316 }
317
flip_compressed_layout(Layout layout)318 inline Layout flip_compressed_layout(Layout layout) {
319 switch (layout) {
320 case kSparseCsr:
321 return kSparseCsc;
322 case kSparseCsc:
323 return kSparseCsr;
324 case kSparseBsr:
325 return kSparseBsc;
326 case kSparseBsc:
327 return kSparseBsr;
328 default:
329 TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
330 return kSparseCsr;
331 }
332 }
333
getBlockSize(Tensor const & self)334 inline DimVector getBlockSize(Tensor const& self) {
335 int64_t n_batch = numBatchDimensions(self);
336 return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
337 }
338
getSymIntBlockSize(Tensor const & self)339 inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
340 if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
341 int64_t n_batch = numBatchDimensions(self);
342 return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
343 } else {
344 return {};
345 }
346 }
347
348 template <typename binary_op_t, typename binary_op_out_t>
only_sparse_compressed_binary_op_trivial_cases(const Tensor & self,const Tensor & other,const Scalar & alpha,Tensor & out,const binary_op_t & binary_op,const binary_op_out_t & binary_op_out)349 inline bool only_sparse_compressed_binary_op_trivial_cases(
350 const Tensor& self,
351 const Tensor& other,
352 const Scalar& alpha,
353 Tensor& out,
354 const binary_op_t& binary_op,
355 const binary_op_out_t& binary_op_out) {
356 // Only sparse compressed! Just like the name says :)
357 TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
358 TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
359 TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
360
361 // Bypass BLAS if there are matches in (self, other, out)
362 if (self.is_same(out) && self.is_same(other)) {
363 binary_op_out(self.values(), other.values(), alpha);
364 return true;
365 }
366 if (self.is_same(other)) {
367 auto [compressed_indices, plain_indices] =
368 at::sparse_csr::getCompressedPlainIndices(self);
369 static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
370 ->set_member_tensors(
371 compressed_indices,
372 plain_indices,
373 binary_op(self.values(), other.values(), alpha),
374 self.sizes());
375 return true;
376 }
377 return false;
378 }
379
only_sparse_compressed_add_trivial_cases(const Tensor & self,const Tensor & other,const Scalar & alpha,Tensor & out)380 inline bool only_sparse_compressed_add_trivial_cases(
381 const Tensor& self,
382 const Tensor& other,
383 const Scalar& alpha,
384 Tensor& out) {
385 return only_sparse_compressed_binary_op_trivial_cases(
386 self,
387 other,
388 alpha,
389 out,
390 [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
391 return v1.add(v2, alpha);
392 },
393 [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
394 return v1.add_(v2, alpha);
395 });
396 }
397
// Builds a new sparse compressed tensor from `input` whose values are
// converted to `dtype`. The compressed/plain indices, sizes, layout, device
// and pinned-memory option are taken from `input`. The tensor is assembled
// with the "unsafe" factory at::_sparse_compressed_tensor_unsafe.
inline Tensor to_type(const Tensor& input, ScalarType dtype) {
  const auto indices = at::sparse_csr::getCompressedPlainIndices(input);
  const Tensor converted_values = input.values().to(dtype);
  return at::_sparse_compressed_tensor_unsafe(
      indices.first,
      indices.second,
      converted_values,
      input.sizes(),
      dtype,
      input.layout(),
      input.device(),
      input.options().pinned_memory_opt());
}
411
412 template <typename acc_t, typename scalar_t>
413 inline std::tuple<Tensor, Tensor> create_acc_buffer(
414 TensorOptions option,
415 ScalarType type,
416 int64_t nnz = -1) {
417 Tensor new_values, new_values_acc;
418 constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
419 bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
420 if constexpr (need_acc) {
421 auto acc_dtype = CppTypeToScalarType<acc_t>::value;
422 new_values_acc = at::empty({}, option.dtype(acc_dtype));
423 new_values = is_integral ? new_values_acc : at::empty({}, option);
424 } else {
425 new_values = new_values_acc = at::empty({}, option);
426 }
427 if (nnz != -1) {
428 return std::make_tuple(
429 new_values.resize_(nnz), new_values_acc.resize_(nnz));
430 } else {
431 return std::make_tuple(new_values, new_values_acc);
432 }
433 }
434
copy_from_acc_buffer(Tensor & new_values,Tensor & new_values_acc)435 inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
436 if (!new_values_acc.is_same(new_values)) {
437 new_values.copy_(new_values_acc);
438 }
439 }
440
441 } // namespace at::sparse_csr
442