1 // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/ATen.h>
3 #include <ATen/core/Tensor.h>
4 #include <optional>
5 #include <ATen/quantized/Quantizer.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/Parallel.h>
8 #include <ATen/TensorOperators.h>
9
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/_autocast_to_full_precision_native.h>
15 #include <ATen/ops/_autocast_to_reduced_precision_native.h>
16 #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
17 #include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
18 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
19 #include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
20 #include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
21 #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
22 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
23 #include <ATen/ops/_sparse_coo_tensor_with_dims_native.h>
24 #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
25 #include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
26 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
27 #include <ATen/ops/_to_copy.h>
28 #include <ATen/ops/_to_copy_native.h>
29 #include <ATen/ops/_to_cpu_native.h>
30 #include <ATen/ops/_to_dense_native.h>
31 #include <ATen/ops/_to_sparse_bsc_native.h>
32 #include <ATen/ops/_to_sparse_bsr_native.h>
33 #include <ATen/ops/_to_sparse_csc_native.h>
34 #include <ATen/ops/_to_sparse_csr_native.h>
35 #include <ATen/ops/_to_sparse_native.h>
36 #include <ATen/ops/arange_native.h>
37 #include <ATen/ops/empty.h>
38 #include <ATen/ops/empty_like.h>
39 #include <ATen/ops/empty_quantized.h>
40 #include <ATen/ops/empty_strided.h>
41 #include <ATen/ops/empty_strided_native.h>
42 #include <ATen/ops/to_dense_backward_native.h>
43 #include <ATen/ops/to_dense_native.h>
44 #include <ATen/ops/to_mkldnn_backward_native.h>
45 #include <ATen/ops/to_native.h>
46 #include <ATen/ops/to_sparse_bsc_native.h>
47 #include <ATen/ops/to_sparse_bsr_native.h>
48 #include <ATen/ops/to_sparse_csc_native.h>
49 #include <ATen/ops/to_sparse_csr_native.h>
50 #include <ATen/ops/to_sparse_native.h>
51 #include <ATen/ops/view_native.h>
52 #include <ATen/ops/zeros.h>
53 #endif
54
55 #include <ATen/SparseCsrTensorUtils.h>
56 #include <ATen/core/ATen_fwd.h>
57 #include <ATen/native/IndexingUtils.h>
58 #include <ATen/native/NonSymbolicBC.h>
59 #include <ATen/native/SparseTensorUtils.h>
60 #include <ATen/native/TensorConversions.h>
61 #include <c10/core/impl/DeviceGuardImplInterface.h>
62 #include <algorithm>
63 #include <numeric>
64
65 namespace at::native {
66
67 namespace {
68 // dense_to_sparse_{csr,bsr,csc,bsc} common helpers
69
70 // Preparation for the N-D dense -> sparse compressed conversion.
71 // The N-D input is converted to 3-D (single batch dim) where we check that the
72 // product of batch dims is nonzero and for each batch the sparse matrix
73 // contained within has the same number of non-zero elements.
74 // The batches are joined along the compressed axis. The generation of indices
75 // for this matrix can be performed in a single step followed by a single step
76 // conversion to restore the batch dimension.
77 void dense_to_sparse_compressed_prepare_check_mask_values_batched(
78 const Layout& target_layout,
79 Tensor& values,
80 Tensor& mask,
81 const int64_t& n_batch_dim) {
82 if (n_batch_dim > 1) {
83 // For inputs with more than 1 batch dim we flatten them out.
84 // Input shape (b0, b1 ..., bn, r, c) -> (b0 * b1 * ... * bn, r ,c)
85 values = values.flatten(0, n_batch_dim - 1);
86 mask = mask.flatten(0, n_batch_dim - 1);
87 }
88
89 // For informative messaging, form the name of the function
90 // to_sparse_{csr,csc,bsr,bsc}.
91 TORCH_CHECK(
92 mask.size(0) > 0,
93 "to_sparse_",
94 // We want the message to match the function name so generate the
95 // lowercase acronym for the layout
96 sparse_csr::layoutToString(target_layout, false, true),
97 ": Expected product of batch dimensions to be non-zero.");
98
99 // Compute the number of non-zero elements in the first batch, expand to full
100 // size
101 auto nse_per_batch = mask.select(0, 0).sum().expand(mask.size(0));
102 TORCH_CHECK(
103 mask.sum({-2, -1}).equal(nse_per_batch),
104 "Expect the same number of specified elements per batch.");
105
106 // We need to join batches into a matrix increasing the length of the
107 // compressed axis. This allows us to create indices for a compressed matrix
108 // and de-batch them later (two kernels). Otherwise we would have to create
109 // indices for each batch individually requiring n_batch kernels. For csr/bsr,
110 // we already have the batch dim adjacent to the compressed axis and can
111 // flatten them together. For csc/bsc, we need to transpose first.
112 // For BSR/CSR (b, r, c) -> (b*r, c)
113 // For BSC/CSC (b, r, c) -> (r, b*c)
114 AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
115 target_layout,
116 "dense_to_sparse_compressed",
117 [&]() {
118 values = values.flatten(0, 1);
119 mask = mask.flatten(0, 1);
120 },
121 [&]() {
122 values = values.transpose(0, 1).flatten(1, 2);
123 mask = mask.transpose(0, 1).flatten(1, 2);
124 });
125 }
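// Illustrative example (not part of the original code): for a CSR target and an
// input with two batch dims, the shapes evolve as
//   values/mask: (b0, b1, r, c) -> flatten batch dims -> (b0*b1, r, c)
//                -> join batch with the compressed (row) axis -> (b0*b1*r, c)
// while a CSC target transposes first: (b0*b1, r, c) -> (r, b0*b1, c) -> (r, b0*b1*c).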
126
127 // This function unfolds the compressed indices of a compressed sparse matrix
128 // into a batched compressed sparse tensor.
129 // This is analogous to an unflatten-like operation:
130 // unflatten(0, {b, r}) for csr/bsr with input shape (r*b, c)
131 // (output shape (b, r, c))
132 // unflatten(1, {b, c}).transpose(0,1) for csc/bsc with input shape (r, c*b)
133 // (output shape (r, b, c) unflatten, (b, r, c) unflatten + transpose)
134 // This only operates on the compressed indices as the plain indices and values
135 // can be manipulated as described above without special handling.
136 // It is a prerequisite for the conversion that the sparsity pattern is sane for
137 // the batched shape. That is each batch has the same number of nonzero
138 // elements.
139 Tensor compressed_to_batched_compressed_indices(
140 const Tensor& compressed_in,
141 const int64_t& n_batch,
142 bool out_int32) {
143 auto n_compressed_per_batch = (compressed_in.size(0) - 1) / n_batch;
144 ScalarType out_type = out_int32 ? ScalarType::Int : ScalarType::Long;
145 auto batched_out = at::zeros(
146 {n_batch, n_compressed_per_batch + 1},
147 compressed_in.options().dtype(out_type));
148
149 // If the compressed dimension has length zero, each batch's compressed index
150 // holds a single element equal to zero, so the zero-initialized output above is already the result.
151 if (n_compressed_per_batch > 0) {
152 // Slice the compressed indices ignoring the leading 0 element and reshape
153 // to n-batch rows
154 auto trailing_slice =
155 compressed_in.slice(0, 1, std::nullopt, 1).reshape({n_batch, -1});
156 // Slice the compressed indices again selecting the elements corresponding
157 // to the batch boundary. The values here will be increasing multiples of
158 // nnz per batch. Reshape to n-batch rows (1 col) for broadcasting.
159 // This is equivalent to arange(n_batch) * nnz_per_batch with the same
160 // reshape
161 auto offsets = compressed_in.slice(0, 0, -1, n_compressed_per_batch)
162 .reshape({n_batch, -1});
163 // Subtracting the offsets from each row of the reshaped compressed indices
164 // gives us the compressed indices within the batch. The leading element of
165 // each row is not computed as it is always zero. We copy into the view on
166 // the output buffer.
167 batched_out.narrow(-1, 1, n_compressed_per_batch)
168 .copy_(trailing_slice - offsets);
169 }
170 return batched_out;
171 }
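// Worked example (illustrative only): with compressed_in = [0, 1, 3, 4, 6] and
// n_batch = 2, n_compressed_per_batch = (5 - 1) / 2 = 2 and nnz per batch = 3:
//   trailing_slice = [[1, 3], [4, 6]]
//   offsets        = [[0], [3]]            (i.e. arange(2) * nnz_per_batch)
//   batched_out    = [[0, 1, 3], [0, 1, 3]]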
172
173 // After generating the member tensors for the sparse compressed matrix, if the
174 // target shape is N-D we must restore the batch dimensions.
175 // A single kernel is used to restore one batch dimension in the compressed
176 // indices. From there the full batch shape is restored by reshape. No special
177 // handling is needed for restoring the batch dimensions of the values or
178 // plain_indices; it can be done with reshape/unflatten.
179 void reshape_2d_sparse_compressed_members_to_nd_batched(
180 const IntArrayRef full_sizes,
181 const int64_t& n_batch_dim,
182 Tensor& compressed_indices,
183 Tensor& plain_indices,
184 Tensor& values) {
185 auto batch_shape = full_sizes.slice(0, n_batch_dim);
186 auto n_batch = std::accumulate(
187 batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int64_t>());
188 // NOTE: using this conversion requires the nnz per batch is the same for all
189 // batches that will be formed. We ensured this was the case on the way in so
190 // it is safe to use this conversion.
191 compressed_indices = compressed_to_batched_compressed_indices(
192 compressed_indices, n_batch, /*out_int32*/ false);
193
194 // We can infer the last dim of the reshape targets, it will be nnz or
195 // nrow/ncol+1 depending on the layout and member tensor targeted.
196 auto batchsize_infer_last = DimVector(batch_shape);
197 batchsize_infer_last.push_back(-1);
198
199 // -1 will be nnz per batch
200 plain_indices = plain_indices.reshape(batchsize_infer_last);
201 // -1 will be ncols (bsc,csc) or nrows (bsr,csr) + 1
202 compressed_indices = compressed_indices.reshape(batchsize_infer_last);
203 // -1 will be nnz (per batch).
204 // Note: Unflatten rather than reshape as it will work
205 // for both blocked and unblocked layouts. reshape works for unblocked layouts
206 // only
207 values = values.unflatten(0, batchsize_infer_last);
208 }
209 } // namespace
210
211 // Take a Device that may not have device_index set (i.e., having it as -1
212 // representing the current device) and return the corresponding Device
213 // according to the actual device at the time of this function call. No-op
214 // if the device_index is set.
215 static inline Device ensure_has_index(Device device) {
216 if (device.is_cpu() || device.has_index()) {
217 return device;
218 }
219 const c10::impl::DeviceGuardImplInterface* impl = c10::impl::getDeviceGuardImpl(device.type());
220 return impl->getDevice();
221 }
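// For illustration: if the current CUDA device is 1, ensure_has_index(Device(kCUDA))
// returns Device(kCUDA, 1), while ensure_has_index(Device(kCUDA, 0)) and CPU devices
// are returned unchanged.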
222
223 static inline std::optional<Device> ensure_has_index(std::optional<Device> device) {
224 if (!device.has_value()) {
225 return std::nullopt;
226 }
227 return ensure_has_index(device.value());
228 }
229
230 Tensor _to_copy(
231 const Tensor& self,
232 std::optional<ScalarType> dtype,
233 std::optional<Layout> layout,
234 std::optional<Device> device,
235 std::optional<bool> pin_memory,
236 bool non_blocking,
237 std::optional<c10::MemoryFormat> optional_memory_format) {
238 TORCH_CHECK(!layout.has_value() || self.layout() == layout.value(),
239 "to(options) doesn't support converting to a different layout, "
240 "but got self.layout being ", self.layout(),
241 " and options.layout set as ", layout.value());
242 auto options = TensorOptions()
243 .dtype(dtype)
244 .layout(layout)
245 .device(device)
246 .pinned_memory(pin_memory);
247
248 if (options.has_device()) {
249 options = options.device(ensure_has_index(options.device()));
250 }
251 // memory_format is handled separately due to MemoryFormat::Preserve logic
252 options = self.options().merge_in(options).memory_format(std::nullopt);
253 auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
254
255 // TODO: Use the dispatcher for this.
256 // Currently there are unenumerated extensibility issues preventing this.
257 if (self.layout() == kSparse) {
258 TORCH_CHECK(
259 memory_format == MemoryFormat::Preserve,
260 "to(options): COO only supports memory format Preserve, but got ", memory_format,
261 " instead.");
262 if (options.device().is_meta()) {
263 return zeros_like(self, options);
264 }
265 auto indices = self._indices();
266 const auto new_indices = at::native::to(
267 indices,
268 indices.scalar_type(),
269 c10::kStrided,
270 device,
271 pin_memory,
272 non_blocking,
273 true, // force copy since we are in _to_copy
274 memory_format);
275 const auto new_values = at::native::to(
276 self._values(),
277 dtype,
278 c10::kStrided,
279 device,
280 pin_memory,
281 non_blocking,
282 true, // force copy since we are in _to_copy
283 memory_format);
284
285 return at::_sparse_coo_tensor_unsafe(
286 new_indices,
287 new_values,
288 self.sizes(),
289 options, self.is_coalesced());
290 } else if (at::sparse_csr::is_sparse_compressed(self)) {
291 TORCH_CHECK(
292 memory_format == MemoryFormat::Preserve,
293 "to(options): ", at::sparse_csr::layoutToString(self.layout()),
294 " only supports memory format Preserve, but got ", memory_format,
295 " instead.");
296
297 if (options.device().is_meta()) {
298 return zeros_like(self, options);
299 }
300
301 auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(self);
302
303 const auto new_values = at::native::to(
304 self.values(),
305 dtype,
306 c10::kStrided,
307 device,
308 pin_memory,
309 non_blocking,
310 true, // force copy since we are in _to_copy
311 memory_format);
312
313 const auto new_compressed_indices = at::native::to(
314 compressed_indices,
315 compressed_indices.scalar_type(),
316 c10::kStrided,
317 device,
318 pin_memory,
319 non_blocking,
320 true, // force copy since we are in _to_copy
321 memory_format);
322
323 const auto new_plain_indices = at::native::to(
324 plain_indices,
325 plain_indices.scalar_type(),
326 c10::kStrided,
327 device,
328 pin_memory,
329 non_blocking,
330 true, // force copy since we are in _to_copy
331 memory_format);
332
333 return at::_sparse_compressed_tensor_unsafe(
334 new_compressed_indices,
335 new_plain_indices,
336 new_values,
337 self.sizes(),
338 options);
339 }
340
341 bool pin_out = (non_blocking && (self.is_cuda() || self.is_privateuseone())
342 && options.device().is_cpu() && (options.layout() == c10::kStrided));
343
344 if (memory_format == MemoryFormat::Preserve) {
345 if (options.device().supports_as_strided()) {
346 if (self.is_non_overlapping_and_dense()) {
347 Tensor r;
348 if (self.is_quantized()) {
349 r = at::empty_quantized(self.sizes(), self, options);
350 at::QuantizerPtr quantizer = r.quantizer();
351 r.copy_(self, non_blocking);
352 set_quantizer_(r, quantizer);
353 } else {
354 r = at::empty_strided(
355 self.sizes(),
356 self.strides(),
357 options.pinned_memory(pin_out));
358 r.copy_(self, non_blocking);
359 }
360 return r;
361 } else if (!self.is_quantized() && self.layout() == kStrided) {
362 Tensor r;
363 auto strides = infer_dense_strides(self.sizes(), self.strides());
364 r = at::empty_strided(
365 self.sizes(),
366 strides,
367 options.pinned_memory(pin_out));
368 r.copy_(self, non_blocking);
369 return r;
370 } else {
371 memory_format = self.suggest_memory_format();
372 }
373 } else {
374 memory_format = self.suggest_memory_format();
375 }
376 }
377 // See Note [Explicit nullopt MemoryFormat argument]
378 // TODO: empty_quantized does not work here. It raises an exception in CheckMemoryFormat.h prior to
379 // empty_affine_quantized/_empty_per_channel_affine_quantized calls
380 // at::empty also does not work here because there is no proper at::empty support for quantized tensors
381 // as it would return a quantized tensor with an UnknownQuantizer
382 auto r = self.is_quantized() ? at::empty_like(self, memory_format)
383 : at::empty_symint(self.sym_sizes(),
384 options.memory_format(memory_format).pinned_memory(pin_out), std::nullopt);
385 r.copy_(self, non_blocking);
386 return r;
387 }
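// Behavior sketch (illustrative): for a non-contiguous but dense input such as
// at::rand({4, 4}).t(), requesting dtype == kHalf with MemoryFormat::Preserve
// takes the empty_strided branch above, so the result keeps the source strides
// while holding freshly copied half-precision data.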
388
389 template <typename T>
390 static inline bool is_null_or_equal_to(const std::optional<T>& test, const T& value) {
391 if (!test.has_value()) {
392 return true;
393 }
394 return test.value() == value;
395 }
396
397 // NOTE: static runtime's to_maybe_copy_out relies on details of this
398 // check; if you change how it works, please update static runtime as
399 // well.
400 bool to_will_alias(
401 const Tensor& self,
402 std::optional<ScalarType> dtype,
403 std::optional<Layout> layout,
404 std::optional<Device> device,
405 bool copy,
406 std::optional<c10::MemoryFormat> optional_memory_format) {
407 auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
408
409 return is_null_or_equal_to(dtype, self.dtype().toScalarType()) &&
410 is_null_or_equal_to(layout, self.layout()) &&
411 is_null_or_equal_to(device, self.device()) &&
412 !copy &&
413 (memory_format == MemoryFormat::Preserve ||
414 self.suggest_memory_format() == memory_format);
415 }
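// For illustration: with all-nullopt dtype/layout/device, copy == false and no
// memory format (defaults to Preserve), to_will_alias returns true, so a call to
// Tensor::to with the tensor's own dtype and device can return the tensor itself;
// requesting copy == true or a different dtype/device forces the _to_copy path.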
416
417 static inline Tensor to_impl(
418 const Tensor& self,
419 std::optional<ScalarType> dtype,
420 std::optional<Layout> layout,
421 std::optional<Device> device,
422 std::optional<bool> pin_memory,
423 bool non_blocking,
424 bool copy,
425 std::optional<c10::MemoryFormat> optional_memory_format) {
426
427 // fast path
428 if (to_will_alias(self, dtype, layout, device, copy, optional_memory_format)) {
429 return self;
430 }
431 return at::_to_copy(
432 self, dtype, layout, device, pin_memory, non_blocking, optional_memory_format);
433 }
434
435 // If the input tensor is fp32, cast it to the given reduced-precision dtype (fp16 or bf16), otherwise leave it alone.
436 // (this is intended to be used internally by the JIT autocast implementation)
437 Tensor _autocast_to_reduced_precision(const Tensor& self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) {
438 if (self.dtype() == at::ScalarType::Float &&
439 ((self.device().is_cuda() && cuda_enabled) ||
440 (self.device().is_cpu() && cpu_enabled))
441 ) {
442 at::ScalarType target = at::ScalarType::Undefined;
443 if (self.device().is_cuda()) {
444 target = cuda_dtype;
445 } else if (self.device().is_cpu()) {
446 target = cpu_dtype;
447 }
448
449 TORCH_INTERNAL_ASSERT(target != at::ScalarType::Undefined, "_autocast_to_reduced_precision requires legit ScalarType argument for given device");
450
451 return to_impl(
452 self, target, std::nullopt, std::nullopt, std::nullopt, false, false, std::nullopt);
453 } else {
454 return self;
455 }
456 }
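// Example (illustrative): with cuda_enabled == true and cuda_dtype == kHalf, a
// CUDA float32 tensor comes back as a half copy, while a float64 tensor or a CPU
// tensor with cpu_enabled == false is returned unchanged.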
457
458 // If the input tensor is fp16 or bf16, cast it to fp32, otherwise leave it alone.
459 // (this is intended to be used internally by the JIT autocast implementation)
460 Tensor _autocast_to_full_precision(const Tensor& self, bool cuda_enabled, bool cpu_enabled) {
461 if ((self.dtype() == at::ScalarType::Half || self.dtype() == at::ScalarType::BFloat16) &&
462 ((self.device().is_cuda() && cuda_enabled) ||
463 (self.device().is_cpu() && cpu_enabled))
464 ) {
465 return to_impl(
466 self, at::ScalarType::Float, std::nullopt, std::nullopt, std::nullopt, false, false, std::nullopt);
467 } else {
468 return self;
469 }
470 }
471
472 Tensor to(
473 const Tensor& self,
474 std::optional<ScalarType> dtype,
475 std::optional<Layout> layout,
476 std::optional<Device> device,
477 std::optional<bool> pin_memory,
478 bool non_blocking,
479 bool copy,
480 std::optional<c10::MemoryFormat> optional_memory_format
481 ) {
482 return to_impl(
483 self,
484 dtype,
485 layout,
486 ensure_has_index(device),
487 pin_memory,
488 non_blocking,
489 copy,
490 optional_memory_format);
491 }
492
493 Tensor to(const Tensor& self, Device device, ScalarType dtype, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
494 return to_impl(
495 self,
496 dtype,
497 std::nullopt,
498 ensure_has_index(device),
499 std::nullopt,
500 non_blocking,
501 copy,
502 optional_memory_format);
503 }
504
505 Tensor to(const Tensor& self, ScalarType dtype, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
506 return to_impl(
507 self,
508 dtype,
509 std::nullopt,
510 std::nullopt,
511 std::nullopt,
512 non_blocking,
513 copy,
514 optional_memory_format);
515 }
516
517 Tensor to(const Tensor& self, const Tensor& other, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
518 auto options = other.options();
519 return to_impl(
520 self,
521 options.dtype().toScalarType(),
522 options.layout(),
523 options.device(),
524 options.pinned_memory(),
525 non_blocking,
526 copy,
527 optional_memory_format);
528 }
529
530 // This op is important primarily for lazy / graph-based backends.
531 // While this vanilla implementation loops through each tensor and independently converts it to CPU,
532 // a lazy backend like XLA might need to synchronize updates across tensors.
533 std::vector<Tensor> _to_cpu(TensorList tensors) {
534 std::vector<Tensor> cpu_tensors;
535 for (const auto& t : tensors) {
536 cpu_tensors.push_back(t.cpu());
537 }
538 return cpu_tensors;
539 }
540
541 Tensor to_dense_backward(const Tensor& grad, const Tensor& input_, std::optional<bool> masked_grad_) {
542 /*
543 For historical reasons, to_dense backward implements masked
544 semantics for sparse tensors, that is, gradients with respect to
545 unspecified elements are ignored. The masked_grad kw argument of
546 to_dense is introduced to allow to_dense to be used in the
547 non-masked semantics context. However, for BC reasons, the default
548 value of the masked_grad kw argument is set to True for now.
549 Eventually, we should eliminate the masked_grad kw argument and
550 let to_dense backward behave according to non-masked
551 semantics. Masked semantics of tensors is implemented in the
552 framework of masked tensors.
553 */
554 const auto input_layout = input_.layout();
555 const bool masked_grad = masked_grad_.value_or(true);
556 switch (input_layout) {
557 case kStrided:
558 // TODO: return grad as it is
559 return grad.to_dense(input_.scalar_type(), masked_grad_);
560 case kSparse:
561 // Autograd operates on the coalesced assumption, i.e. no duplicate values.
562 if (masked_grad) {
563 return grad.sparse_mask(input_.coalesce());
564 } else {
565 // TODO: return grad as it is
566 return grad.to_sparse(input_.sparse_dim());
567 }
568 case kSparseCsr:
569 case kSparseCsc:
570 // TODO: add efficient CSR/CSC support for sparse_mask
571 if (masked_grad) {
572 return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout);
573 } else {
574 // TODO: return grad as it is
575 return grad.to_sparse(input_layout, /*blocksize=*/std::nullopt, /*dense_dim=*/input_.dense_dim());
576 }
577 case kSparseBsr:
578 case kSparseBsc: {
579 // TODO: add efficient BSR/BSC support for sparse_mask
580 const auto blocksize = at::sparse_csr::getBlockSize(input_);
581 if (masked_grad) {
582 return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout, blocksize);
583 } else {
584 // TODO: return grad as it is
585 return grad.to_sparse(input_layout, blocksize, input_.dense_dim());
586 }
587 }
588 case kMkldnn:
589 return grad.to_mkldnn(input_.scalar_type());
590 default:
591 AT_ERROR("to_dense_backward: Unsupported input layout: ", input_layout);
592 return Tensor{};
593 }
594 }
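// Example of the masked semantics above (illustrative): if input_ is a sparse COO
// tensor that only specifies element (0, 1), then with masked_grad == true the
// incoming dense grad is filtered through grad.sparse_mask(input_.coalesce()), so
// only the gradient at (0, 1) survives; with masked_grad == false the full dense
// gradient is merely converted back to the sparse layout.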
595
596 Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) {
597 AT_ASSERT(input_.layout() == c10::kStrided);
598 return grad.to_dense(input_.scalar_type());
599 }
600
601 Tensor to_dense(const Tensor& tensor, std::optional<c10::ScalarType> dtype, std::optional<bool> masked_grad) {
602 if (tensor.layout() == c10::kSparse) {
603 return tensor._to_dense(dtype, masked_grad);
604 }
605 if (tensor.layout() == c10::kSparseCsr ||
606 tensor.layout() == c10::kSparseCsc ||
607 tensor.layout() == c10::kSparseBsr ||
608 tensor.layout() == c10::kSparseBsc) {
609 return tensor._to_dense(dtype, masked_grad);
610 }
611 if (tensor.layout() == c10::kMkldnn) {
612 return tensor._to_dense(dtype, masked_grad);
613 }
614 TORCH_CHECK(
615 tensor.layout() == c10::kStrided,
616 "to_dense does not support layout ",
617 tensor.layout());
618 if (dtype) {
619 return tensor.to(*dtype);
620 }
621 return tensor;
622 }
623
624 Tensor sparse_to_dense(const Tensor& self, std::optional<ScalarType> dtype, std::optional<bool> masked) {
625 TORCH_CHECK(
626 !dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
627 Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
628 return dst.add_(self);
629 }
630
631 Tensor sparse_compressed_to_dense(
632 const Tensor& self,
633 std::optional<ScalarType> dtype,
634 std::optional<bool> masked_grad) {
635 TORCH_CHECK(
636 !dtype.has_value(),
637 "dtype argument is not supported by sparse_csr_to_dense");
638
639 if (self.numel() == 0) {
640 return at::zeros(self.sizes(), self.options().layout(kStrided));
641 }
642
643 auto batch_ndim = sparse_csr::numBatchDimensions(self);
644
645 auto compressed_rows = self.layout() == kSparseCsr || self.layout() == kSparseBsr;
646 auto block_sparse = self.layout() == kSparseBsr || self.layout() == kSparseBsc;
647
648 auto [compressed_indices, plain_indices] =
649 sparse_csr::getCompressedPlainIndices(self);
650
651 auto values = self.values();
652 Tensor dense = at::zeros(self.sizes(), self.options().layout(kStrided));
653
654 if (batch_ndim == 0) {
655 // Pad shape so we can treat non-batched like batched, we will
656 // squeeze out the phantom batch dim at the end.
657 compressed_indices.unsqueeze_(0);
658 plain_indices.unsqueeze_(0);
659 values.unsqueeze_(0);
660 dense.unsqueeze_(0);
661 }
662 if (batch_ndim > 1) {
663 // Flatten batch dims
664 compressed_indices = compressed_indices.flatten(0, batch_ndim - 1);
665 plain_indices = plain_indices.flatten(0, batch_ndim - 1);
666 values = values.flatten(0, batch_ndim - 1);
667 dense = dense.flatten(0, batch_ndim - 1);
668 }
669
670 // At this point there is only one batch dim, which either existed already or
671 // was flattened from multiple batch dims. Now, reshape the resulting
672 // dense matrix so that this single batch dim is joined with sparse
673 // dims into a single dim, so that the remaining dims are only block
674 // dims eventually, and then dense dims.
675 auto n_batch = values.size(0);
676 int64_t nrows = 0, ncols = 0;
677 auto dense_reshaped_sizes = dense.sizes().vec();
678 if (!block_sparse) {
679 nrows = self.size(batch_ndim);
680 ncols = self.size(batch_ndim + 1);
681 dense_reshaped_sizes.erase(dense_reshaped_sizes.begin(), dense_reshaped_sizes.begin() + 2);
682 } else {
683 std::array<int64_t, 2> blocksize = {values.size(2), values.size(3)};
684 nrows = self.size(batch_ndim) / blocksize[0];
685 ncols = self.size(batch_ndim + 1) / blocksize[1];
686 dense_reshaped_sizes[1] = blocksize[0];
687 dense_reshaped_sizes[2] = blocksize[1];
688 }
689 dense_reshaped_sizes[0] = n_batch * nrows * ncols;
690 dense = dense.reshape(dense_reshaped_sizes);
691
692 // Calculate batch, row and column indices for non-zeros in the
693 // sparse matrix, and use these to calculate corresponding indices
694 // into the dense matrix reshaped as above. Then, update dense
695 // matrix by adding sparse matrix values into elements with indices
696 // calculated this way.
697 auto options = compressed_indices.options();
698 auto nnz_per_batch = values.size(1);
699 auto batch_indices = at::arange(0, n_batch, options).repeat_interleave(nnz_per_batch);
700 auto ncompressed = compressed_rows ? nrows : ncols;
701 auto compressed_indices_over_all_batches =
702 at::cat({compressed_indices.slice(1, 0, ncompressed).flatten()
703 + nnz_per_batch * at::arange(0, n_batch, options).repeat_interleave(ncompressed),
704 n_batch * nnz_per_batch * at::ones({1}, options)});
705 Tensor indices = at::_convert_indices_from_csr_to_coo(
706 compressed_indices_over_all_batches,
707 plain_indices.flatten(),
708 false,
709 !compressed_rows);
710 auto row_indices = indices.select(0, 0);
711 auto col_indices = indices.select(0, 1);
712 if (compressed_rows) {
713 row_indices -= batch_indices * nrows;
714 } else {
715 col_indices -= batch_indices * ncols;
716 }
717 auto offsets = col_indices + row_indices * ncols + batch_indices * nrows * ncols;
718 dense.index_add_(0, offsets, values.flatten(0, 1));
719
720 // Un-tile the result. The final reshape uses the original
721 // self.sizes() which will squeeze out the extra batch dim if we put
722 // one in.
723 if (!block_sparse) {
724 return dense.reshape(self.sizes());
725 } else {
726 return dense
727 .unflatten(0, {-1, nrows, ncols})
728 .transpose(2, 3)
729 .reshape(self.sizes());
730 }
731 }
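// Worked example (illustrative only): a single CSR matrix with
//   crow_indices = [0, 1, 3], col_indices = [1, 0, 1], values = [a, b, c]
// and shape (2, 2). After the phantom batch dim, n_batch = 1, nrows = ncols = 2,
// nnz_per_batch = 3, and the CSR->COO conversion yields row = [0, 1, 1],
// col = [1, 0, 1]. Hence offsets = col + row * ncols = [1, 2, 3], and index_add_
// scatters [a, b, c] into the flattened buffer, giving [[0, a], [b, c]].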
732
733 // Computes the strides for view_dtype output when the view dtype is
734 // smaller than the original dtype
735 inline SymDimVector compute_strides_for_view_dtype_downsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
736 const int64_t ndim = old_strides.size();
737
738 TORCH_CHECK(
739 old_strides[ndim - 1] == 1,
740 "self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
741 " (different element sizes), but got ", old_strides[ndim - 1]);
742
743 SymDimVector new_strides(ndim);
744 for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
745 new_strides[dim_idx] = old_strides[dim_idx] * size_ratio;
746 }
747 new_strides[ndim - 1] = 1;
748 return new_strides;
749 }
750
751 // Computes the strides for view_dtype output when the view dtype is
752 // larger than the original dtype
753 inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
754 const int64_t ndim = old_strides.size();
755 TORCH_CHECK(
756 old_strides[ndim - 1] == 1,
757 "self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
758 " (different element sizes), but got ", old_strides[ndim - 1]);
759
760 SymDimVector new_strides(ndim);
761 for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
762 TORCH_CHECK(
763 (old_strides[dim_idx] % size_ratio) == 0,
764 "self.stride(", dim_idx, ") must be divisible by ", size_ratio,
765 " to view ", old_dtype, " as ", new_dtype, " (different element sizes), ",
766 "but got ", old_strides[dim_idx]);
767
768 new_strides[dim_idx] = old_strides[dim_idx] / size_ratio;
769 }
770 new_strides[ndim - 1] = 1;
771 return new_strides;
772 }
773
774 Tensor view_dtype(const Tensor& self, ScalarType dtype) {
775 if (self.scalar_type() == dtype) {
776 return self;
777 }
778 const auto type_meta = c10::scalarTypeToTypeMeta(dtype);
779 TORCH_CHECK(!self.is_conj(),
780 "torch.Tensor.view is not supported for conjugate view tensors when converting to a different dtype.");
781 TORCH_CHECK(!self.is_neg(),
782 "torch.Tensor.view is not supported for tensors with negative bit set when converting to a different dtype.");
783
784 int64_t self_element_size = self.element_size();
785 int64_t new_element_size = static_cast<int64_t>(type_meta.itemsize());
786
787 Storage storage = self.storage();
788 auto new_tensor = detail::make_tensor<TensorImpl>(
789 std::move(storage), self.key_set(), type_meta);
790 auto* impl = new_tensor.unsafeGetTensorImpl();
791
792 if (self_element_size == new_element_size) {
793 impl->set_sizes_and_strides(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
794
795 } else if (self.dim() == 0) {
796 TORCH_CHECK(false,
797 "self.dim() cannot be 0 to view ", self.scalar_type(), " as ",
798 dtype, " (different element sizes)");
799
800 } else if (self_element_size > new_element_size) {
801 // Downsizing element size
802
803 int64_t size_ratio = self_element_size / new_element_size;
804 auto new_strides = compute_strides_for_view_dtype_downsize(
805 self.sym_strides(), size_ratio, self.scalar_type(), dtype);
806
807 auto old_sizes = self.sym_sizes();
808 SymDimVector new_sizes(self.dim());
809 std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
810 new_sizes[self.dim() - 1] *= size_ratio;
811
812 auto new_storage_offset = size_ratio * self.sym_storage_offset();
813
814 impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
815
816 } else {
817 // Upsizing element size
818
819 int64_t size_ratio = new_element_size / self_element_size;
820
821 TORCH_CHECK(
822 (self.sym_size(-1) % size_ratio) == 0,
823 "self.size(-1) must be divisible by ", size_ratio, " to view ",
824 self.scalar_type(), " as ", dtype, " (different element sizes), ",
825 "but got ", self.sym_size(-1));
826
827 TORCH_CHECK(
828 (self.sym_storage_offset() % size_ratio) == 0,
829 "self.storage_offset() must be divisible by ", size_ratio, " to view ",
830 self.scalar_type(), " as ", dtype, " (different element sizes), but got ",
831 self.sym_storage_offset());
832
833 auto new_strides = compute_strides_for_view_dtype_upsize(
834 self.sym_strides(), size_ratio, self.scalar_type(), dtype);
835
836 auto old_sizes = self.sym_sizes();
837 SymDimVector new_sizes(self.dim());
838 std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
839 new_sizes[self.dim() - 1] /= size_ratio;
840
841 auto new_storage_offset = self.sym_storage_offset() / size_ratio;
842
843 impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
844 }
845
846 return new_tensor;
847 }
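// Usage sketch (illustrative): viewing a contiguous (2, 4) float tensor
// (element size 4) as int16 (element size 2) gives size_ratio == 2, so the
// result has sizes (2, 8) and strides (8, 1); viewing the same tensor as
// double instead halves the last dimension: sizes (2, 2), strides (2, 1).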
848
849 static Tensor _tile_tensor(const Tensor& self, IntArrayRef blocksize) {
850 // This code turns a matrix into a sequence of blocks
851 //
852 // Given matrix
853 //
854 // 1 2 3 4
855 // 5 6 7 8
856 // 9 10 11 12
857 // 14 15 16 17
858 //
859 // _tile_tensor(matrix, {2, 2}) will yield the following 2 by 2 blocks
860 //
861 // 1 2 | 3 4 | 9 10 | 11 12
862 // 5 6 | 7 8 | 14 15 | 16 17
863 //
864 // via a 4D Tensor of shape (2, 2, 2, 2)
865 //
866 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0);
867 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0);
868 auto block_size_0 = self.size(0) / blocksize[0];
869 auto block_size_1 = self.size(1) / blocksize[1];
870
871 auto new_shape = DimVector({block_size_0, blocksize[0], block_size_1, blocksize[1]});
872 new_shape.append(DimVector(self.sizes().slice(2, self.dim() - 2)));
873 return self.reshape(new_shape)
874 .transpose(1, 2)
875 .contiguous();
876 }
877
878 static Tensor _batch_tile_tensor(const Tensor& self, IntArrayRef blocksize, const int64_t dense_dim) {
879 if (self.dim() == 2 + dense_dim) {
880 return _tile_tensor(self, blocksize);
881 }
882 auto n_batch_dim = self.dim() - 2 - dense_dim;
883 // Same as _tile_tensor, just per matrix entry of self, if self is 3D.
884 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0);
885 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0);
886 auto block_size_0 = self.size(n_batch_dim) / blocksize[0];
887 auto block_size_1 = self.size(n_batch_dim + 1) / blocksize[1];
888 auto tiled_sizes = DimVector(self.sizes().slice(0, n_batch_dim));
889 tiled_sizes.push_back(block_size_0);
890 tiled_sizes.push_back(blocksize[0]);
891 tiled_sizes.push_back(block_size_1);
892 tiled_sizes.push_back(blocksize[1]);
893 tiled_sizes.append(DimVector(self.sizes().slice(n_batch_dim + 2, dense_dim)));
894 return self.reshape(tiled_sizes).transpose(n_batch_dim + 1, n_batch_dim + 2).contiguous();
895 }
896
897 static Tensor _mask_to_indices(const Tensor& mask) {
898 // This function returns a vector of the indices at which given
899 // boolean mask is True. at::nonzero can achieve the same, but
900 // we have yet to compare the performance difference.
901 TORCH_CHECK(mask.dim() == 1, "Currently _mask_to_indices only supports 1-d masks.");
902 TORCH_CHECK(mask.dtype() == at::kBool, "Expected mask to be of dtype bool.");
903 return at::native::arange(
904 mask.numel(), at::kLong, kStrided, mask.device())
905 .masked_select(mask);
906 }
907
908 static std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
909 Tensor not_zero_mask,
910 ScalarType index_dtype,
911 Device index_device) {
912 auto col_indices =
913 at::native::arange(not_zero_mask.size(-1), index_dtype, kStrided, index_device)
914 .view({1, not_zero_mask.size(-1)})
915 .expand_as(not_zero_mask)
916 .masked_select(not_zero_mask);
917 auto row_indices =
918 at::native::arange(
919 not_zero_mask.size(-2), index_dtype, kStrided, index_device)
920 .view({not_zero_mask.size(-2), 1})
921 .expand_as(not_zero_mask)
922 .masked_select(not_zero_mask);
923 return std::pair<Tensor, Tensor>(col_indices, row_indices);
924 }
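// Worked example (illustrative): for the 2x3 mask [[T, F, T], [F, T, F]] the
// broadcast arange grids are [[0, 1, 2], [0, 1, 2]] (columns) and
// [[0, 0, 0], [1, 1, 1]] (rows), so masked_select yields
// col_indices = [0, 2, 1] and row_indices = [0, 0, 1].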
925
926 // Sparse layout conversions Start
927
928 static inline
929 void _to_sparse_check_arguments(const std::string& funcname, const Tensor& self, const int64_t sparse_dim) {
930 auto layout_from = self.layout();
931
932 auto layout_from_valid = layout_from == kStrided || layout_from == kSparse || at::sparse_csr::is_sparse_compressed(layout_from);
933 if (!layout_from_valid) {
934 AT_ERROR(funcname, ": unexpected source layout ", layout_from);
935 }
936
937 if (layout_from == kStrided) {
938 if (sparse_dim == 0 && self.dim() > 0) {
939 AT_ERROR(funcname, ": sparse_dim argument must be in >0 when self.dim()>0");
940 }
941 if (sparse_dim < 0 || sparse_dim > self.dim()) {
942 AT_ERROR(funcname, ": sparse_dim argument must be in [0,", self.dim(), "] range, but ", sparse_dim, " is given");
943 }
944 } else if (layout_from == kSparse) {
945 if (sparse_dim != self.sparse_dim()) {
946 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=self.sparse_dim() is not supported");
947 }
948 } else if (at::sparse_csr::is_sparse_compressed(layout_from)) {
949 if (sparse_dim != 2) {
950 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=2 is not supported");
951 }
952 }
953 }
954
955 static inline
956 void _to_sparse_check_arguments(const std::string& funcname, const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
957 auto layout_from = self.layout();
958 auto layout_to = layout.value_or(kSparse);
959
960 auto layout_from_valid = layout_from == kStrided || layout_from == kSparse || at::sparse_csr::is_sparse_compressed(layout_from);
961 if (!layout_from_valid) {
962 AT_ERROR(funcname, ": unexpected source layout ", layout_from);
963 }
964 auto layout_to_valid = layout_to == kStrided || layout_to == kSparse || at::sparse_csr::is_sparse_compressed(layout_to);
965 if (!layout_to_valid) {
966 AT_ERROR(funcname, ": unexpected source layout ", layout_from);
967 }
968
969 if (layout_from == kSparse && layout_to != kSparse) {
970 if (self.sparse_dim() != 2) {
971 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " for input tensors with sparse_dim()!=2 is not supported");
972 }
973 }
974
975 if ((layout_from == kSparseCsr || layout_from == kSparseCsc) &&
976 (layout_to == kSparseBsr || layout_to == kSparseBsc)) {
977 if (sparse_csr::numBatchDimensions(self) > 0) {
978 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " for batched inputs is not supported");
979 }
980 }
981
982 if (blocksize.has_value()) {
983 if (blocksize.value().size() != 2) {
984 AT_ERROR(funcname, ": blocksize needs to be a tuple of size 2, but got ", blocksize.value().size());
985 }
986 auto blocksize_to = *blocksize;
987 if (blocksize_to[0] <= 0 || blocksize_to[1] <= 0) {
988 AT_ERROR(funcname, ": blocksize needs to be positive, but got ", blocksize_to);
989 }
990
991 if (layout_to == kSparseBsr || layout_to == kSparseBsc) {
992 if (layout_from == kSparseBsr || layout_from == kSparseBsc) {
993 auto blocksize_from = at::sparse_csr::getBlockSize(self);
994 if (!(blocksize_to == blocksize_from)) {
995 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize changed from ", blocksize_from, " to ", blocksize_to, " is not supported");
996 }
997 } else {
998 auto dense_dim = (layout_from == kStrided) ? dense_dim_opt.value_or(0) : self.dense_dim();
999 auto sparse_row_dim = -(dense_dim + 2);
1000 auto sparse_col_dim = -(dense_dim + 1);
1001 if ((self.size(sparse_row_dim) % blocksize_to[0] != 0) ||
1002 (self.size(sparse_col_dim) % blocksize_to[1] != 0)) {
1003 AT_ERROR(funcname, ": tensor sparse size (", self.size(sparse_row_dim), ",", self.size(sparse_row_dim), ") must be divisible by given blocksize (", blocksize_to[0], ",", blocksize_to[1], ")");
1004 }
1005 }
1006 } else {
1007 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize argument given is not supported");
1008 }
1009 } else {
1010 if ((layout_to == kSparseBsr || layout_to == kSparseBsc) &&
1011 !(layout_from == kSparseBsr || layout_from == kSparseBsc)) {
1012 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " without blocksize argument given is not supported");
1013 }
1014 }
1015
1016 if (dense_dim_opt.has_value()) {
1017 if (layout_from != kStrided) {
1018 AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with dense_dim argument given is not supported");
1019 }
1020
1021 auto dense_dim = *dense_dim_opt;
1022 if (layout_to == kSparse) {
1023 if (dense_dim == self.dim() && self.dim() > 0) {
1024 AT_ERROR(funcname, ": dense_dim argument must be !=self.dim() when self.dim()>0");
1025 }
1026 if (dense_dim < 0 || dense_dim > self.dim()) {
1027 AT_ERROR(funcname, ": dense_dim argument must be in [0,", self.dim(), "] range, but ", dense_dim, " is given");
1028 }
1029 } else {
1030 if (dense_dim < 0 || dense_dim > self.dim() - 2) {
1031 AT_ERROR(funcname, ": dense_dim argument must be in [0,", self.dim() - 2, "] range, but ", dense_dim, " is given");
1032 }
1033 }
1034 }
1035 }
1036
1037 template<Layout target_layout>
1038 static Tensor dense_to_sparse_compressed(const Tensor& self, const Tensor& self_mask, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1039 static_assert(target_layout == Layout::SparseCsr || target_layout == Layout::SparseCsc
1040 || target_layout == Layout::SparseBsr || target_layout == Layout::SparseBsc,
1041 "invalid layout template parameter for dense_to_sparse_compressed");
1042 constexpr auto compressed_rows_layout = target_layout == Layout::SparseCsr || target_layout == Layout::SparseBsr;
1043 constexpr auto blocked_layout = target_layout == Layout::SparseBsr || target_layout == Layout::SparseBsc;
1044
1045 int64_t dense_dim = dense_dim_opt.value_or(0);
1046
1047 // Reshape values so that the block dims are explicitly added, and
1048 // calculate a mask tensor that has only batch and sparse dims, and
1049 // value true whenever sparse matrix has a non-zero element over
1050 // corresponding block and dense dims, and false otherwise.
1051 auto n_batch_dim = self.dim() - 2 - dense_dim;
1052 auto is_batched = n_batch_dim > 0;
1053 auto values = blocked_layout ? _batch_tile_tensor(self, blocksize, dense_dim) : self;
1054 auto not_zero_mask = blocked_layout ? _batch_tile_tensor(self_mask, blocksize, dense_dim) : self_mask;
1055 if (blocked_layout || dense_dim > 0) {
1056 std::vector<int64_t> reduce_dim((blocked_layout ? 2 : 0) + dense_dim);
1057 std::iota(reduce_dim.begin(), reduce_dim.end(), n_batch_dim + 2);
1058 not_zero_mask = not_zero_mask.sum(reduce_dim) != 0;
1059 }
1060
1061 if (is_batched) {
1062 // Prepare for the conversion, in particular join the batch dims
1063 // and the compressed dim into the single dim.
1064 dense_to_sparse_compressed_prepare_check_mask_values_batched(
1065 target_layout, values, not_zero_mask, n_batch_dim);
1066 }
1067
1068 // Calculate sparse matrix row and col indices and then, depending
1069 // on the target layout, corresponding compressed and sparse
1070 // indices. Use the mask tensor calculated above to generate the sparse
1071 // matrix values tensor.
1072 Tensor row_indices;
1073 Tensor col_indices;
1074 Tensor compressed_indices;
1075 if (compressed_rows_layout) {
1076 std::tie(col_indices, row_indices) = _not_zero_mask_to_col_row_indices(
1077 not_zero_mask, at::kLong, not_zero_mask.device());
1078 compressed_indices = at::_convert_indices_from_coo_to_csr(
1079 row_indices, not_zero_mask.size(0), false /*out_int32*/);
1080 {
1081 auto mask_indices = _mask_to_indices(not_zero_mask.flatten());
1082 values = values.flatten(0, 1).index_select(0, mask_indices);
1083 }
1084 } else {
1085 std::tie(row_indices, col_indices) = _not_zero_mask_to_col_row_indices(
1086 not_zero_mask.transpose(1, 0), at::kLong, not_zero_mask.device());
1087 compressed_indices = at::_convert_indices_from_coo_to_csr(
1088 col_indices, not_zero_mask.size(-1), false /*out_int32*/);
1089 {
1090 auto mask_indices = _mask_to_indices(not_zero_mask.transpose(0, 1).flatten());
1091 values = values.transpose(0, 1).flatten(0, 1).index_select(0, mask_indices);
1092 }
1093 }
1094 Tensor& plain_indices = compressed_rows_layout ? col_indices : row_indices;
1095
1096 if (is_batched) {
1097 // Restore the batch dims and compressed dim.
1098 reshape_2d_sparse_compressed_members_to_nd_batched(
1099 self.sizes(), n_batch_dim, compressed_indices, plain_indices, values);
1100 }
1101
1102 // Create compressed sparse matrix with the target layout.
1103 return at::_sparse_compressed_tensor_unsafe(
1104 compressed_indices,
1105 plain_indices,
1106 values,
1107 self.sizes(),
1108 self.options().layout(target_layout));
1109 }
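// Illustrative example: for self = [[0, 1], [2, 3]] with self_mask = (self != 0),
// target_layout == SparseCsr, no blocks and no dense dims,
// not_zero_mask = [[F, T], [T, T]] and the helpers above produce
// row_indices = [0, 1, 1], col_indices = [1, 0, 1]; hence
// compressed_indices (crow) = [0, 1, 3], plain_indices (col) = [1, 0, 1] and
// values = [1, 2, 3].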
1110
1111 Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1112 auto layout_to = layout.value_or(kSparse);
1113 TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "dense_to_sparse_with_mask: unexpected same input and output layout");
1114 TORCH_INTERNAL_ASSERT(self.layout() == mask.layout(),
1115 "dense_to_sparse_with_mask: expected mask layout ", self.layout(), ", got ", mask.layout());
1116 TORCH_INTERNAL_ASSERT(self.sizes() == mask.sizes(),
1117 "dense_to_sparse_with_mask: expected mask size ", self.sizes(), ", got ", mask.sizes());
1118 _to_sparse_check_arguments("dense_to_sparse_with_mask", self, layout, blocksize, dense_dim_opt);
1119
1120 switch (layout_to) {
1121 case kSparse:
1122 return self.sparse_mask(mask.to_sparse(self.dim() - dense_dim_opt.value_or(0)));
1123 case kSparseCsr:
1124 return dense_to_sparse_compressed<Layout::SparseCsr>(self, mask, {}, dense_dim_opt);
1125 case kSparseCsc:
1126 return dense_to_sparse_compressed<Layout::SparseCsc>(self, mask, {}, dense_dim_opt);
1127 case kSparseBsr:
1128 return dense_to_sparse_compressed<Layout::SparseBsr>(self, mask, *blocksize, dense_dim_opt);
1129 case kSparseBsc:
1130 return dense_to_sparse_compressed<Layout::SparseBsc>(self, mask, *blocksize, dense_dim_opt);
1131 default:
1132 break;
1133 }
1134
1135 AT_ERROR("dense_to_sparse_with_mask: ", self.layout(), " to ", layout_to, " conversion not supported");
1136 return Tensor{};
1137 }
1138
1139 Tensor dense_to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1140 auto layout_to = kSparseCsr;
1141 _to_sparse_check_arguments("dense_to_sparse_csr", self, layout_to, {}, dense_dim_opt);
1142
1143 return dense_to_sparse_compressed<Layout::SparseCsr>(self, self != 0, {}, dense_dim_opt);
1144 }
1145
1146 Tensor dense_to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1147 auto layout_to = kSparseCsc;
1148 _to_sparse_check_arguments("dense_to_sparse_csc", self, layout_to, {}, dense_dim_opt);
1149
1150 return dense_to_sparse_compressed<Layout::SparseCsc>(self, self != 0, {}, dense_dim_opt);
1151 }
1152
1153 Tensor dense_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1154 auto layout_to = kSparseBsr;
1155 _to_sparse_check_arguments("dense_to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);
1156
1157 return dense_to_sparse_compressed<Layout::SparseBsr>(self, self != 0, blocksize, dense_dim_opt);
1158 }
1159
1160 Tensor dense_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1161 auto layout_to = kSparseBsc;
1162 _to_sparse_check_arguments("dense_to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);
1163
1164 return dense_to_sparse_compressed<Layout::SparseBsc>(self, self != 0, blocksize, dense_dim_opt);
1165 }
1166
1167 Tensor dense_to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1168 auto layout_to = layout.value_or(kSparse);
1169 TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "dense_to_sparse: unexpected same input and output layout");
1170 _to_sparse_check_arguments("dense_to_sparse", self, layout, blocksize, dense_dim_opt);
1171
1172 switch (layout_to) {
1173 case kSparse:
1174 return self.to_sparse(self.dim() - dense_dim_opt.value_or(0));
1175 case kSparseCsr:
1176 return self.to_sparse_csr(dense_dim_opt);
1177 case kSparseCsc:
1178 return self.to_sparse_csc(dense_dim_opt);
1179 case kSparseBsr:
1180 return self.to_sparse_bsr(*blocksize, dense_dim_opt);
1181 case kSparseBsc:
1182 return self.to_sparse_bsc(*blocksize, dense_dim_opt);
1183 default:
1184 break;
1185 }
1186
1187 AT_ERROR("dense_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported");
1188 return Tensor{};
1189 }
1190
1191 Tensor dense_to_sparse(const Tensor& self, int64_t sparse_dim) {
1192 _to_sparse_check_arguments("dense_to_sparse", self, sparse_dim);
1193
1194 int64_t dims = self.dim();
1195 at::TensorOptions sparse_options = self.options().layout(kSparse);
1196 std::vector<int64_t> sizes = self.sizes().vec();
1197 Tensor nz = self.nonzero().transpose(0, 1);
1198 if (nz.size(1) == 0) {
1199 auto sparse = new_with_dims_sparse(
1200 sparse_dim,
1201 dims - sparse_dim,
1202 sizes,
1203 optTypeMetaToScalarType(sparse_options.dtype_opt()),
1204 sparse_options.layout_opt(),
1205 sparse_options.device_opt(),
1206 sparse_options.pinned_memory_opt());
1207 return sparse._coalesced_(true);
1208 }
1209 Tensor indices;
1210 if (sparse_dim == dims) {
1211 indices = nz.clone();
1212 } else {
1213 Tensor i = nz.narrow(0, 0, sparse_dim);
1214 std::tie(indices, std::ignore, std::ignore) = unique_dim(i, 1);
1215 indices = indices.contiguous(); // many sparse CUDA kernels require
1216 // contiguity, see issue #12633
1217 }
1218
1219 Tensor values;
1220 if (self.dim() > 0) {
1221 auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0));
1222 values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
1223 } else {
1224 AT_ASSERT(nz.sizes().equals({0, 1}));
1225 // In this case, indices is a clone of nz, which is a tensor of shape (0,
1226 // 1). Given sparse tensor invariants, values should be shape (1,)
1227 values = self.unsqueeze(0).clone(at::MemoryFormat::Preserve);
1228 }
1229
1230 Tensor sparse = at::sparse_coo_tensor(indices, values, sizes, sparse_options);
1231 return sparse._coalesced_(true);
1232 }
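// Illustrative example: for self = [[0, 1], [2, 0]] and sparse_dim == 2,
// nonzero().transpose(0, 1) gives nz = [[0, 1], [1, 0]], so the result is a
// coalesced COO tensor with indices = [[0, 1], [1, 0]] and values = [1, 2].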
1233
1234 static Tensor sparse_compressed_to_flipped(
1235 const Tensor& self,
1236 std::optional<IntArrayRef> blocksize,
1237 const std::string& name) {
1238 const auto layout = self.layout();
1239 // NOTE: errors on non-compressed sparse layouts.
1240 const auto flipped_layout = at::sparse_csr::flip_compressed_layout(layout);
1241
1242 // Suppose compressed_indices represent rows of an input in either
1243 // CSR or BSR sparse compressed format.
1244 // In order to convert a batched CSR/BSR index into a batched CSC/BSC index
1245 // we perform the following steps:
1246 // 1. Convert a sparse compressed index representing batches of matrices of
1247 // shape (b, r, c) to a sparse compressed index that represents a single
1248 // matrix of shape (b * r, c).
1249 // 2. Turn the compressed indices of the matrix of shape (b * r, c) into
1250 // COO indices.
1251 // 3. Map these COO indices into the COO indices of a matrix of shape (r, b * c)
1252 // such that if A is a matrix of shape (b * r, c) and B is a matrix of shape
1253 // (r, b * c) such that
1254 // A[(k * r):(k * r + r), :] = B[:, (k * c):(k * c + c)] for all k in arange(b),
1255 // then A[i, j] = B[i', j'].
1256 // This is equivalent to finding indices that match values of matrices
1257 // tiled vertically to values of the same matrices tiled horizontally.
1258 // 4. Convert the COO indices to the CSC/BSC indices and form the output.
1259 //
1260 // NOTE: the reason behind vertical/horizontal tiling is to be able to transform
1261 // indices over all matrices in the batch in a single kernel call, since
1262 // all the existing coo <-> compressed indices conversion methods assume
1263 // a single matrix.
1264 //
1265 // CSC/BSC inputs are handled in a similar fashion with a "transposed" argument.
1266 // See the comments below for detailed explanations on how exactly each step
1267 // is performed.
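// Worked example of step 1 (illustrative): two CSR batches of shape (2, 2) with
// crow_indices = [[0, 1, 2], [0, 0, 2]] and nnz == 2 per batch. The per-batch
// offsets [0, 2] are added to the leading crow entries [[0, 1], [0, 0]], giving
// [0, 1, 2, 2]; appending nnz * batch_numel == 4 yields [0, 1, 2, 2, 4], which is
// exactly the crow index of the two matrices stacked vertically into a (4, 2)
// matrix.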

  Tensor compressed_indices, plain_indices;
  std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(self);
  auto values = self.values();
  const auto nnz = plain_indices.size(-1);

  const auto n_batches = compressed_indices.dim() - 1;
  auto n_batches_nonzero = n_batches;
  // Insert a fake batch dim for simplicity.
  if (!n_batches) {
    n_batches_nonzero = 1;
    compressed_indices.unsqueeze_(0);
    plain_indices.unsqueeze_(0);
    values.unsqueeze_(0);
  }

  // NOTE: sparse_dim stores the true sparse dims only for CSR/CSC inputs;
  // for BSR/BSC inputs it stores <true sparse dims> / <blocksize>.
  // In other words, sparse_dim stores the ranges of valid indices in the
  // row/col dims.
  const auto sparse_dim = [&]() -> at::DimVector {
    auto sparse_dim = at::DimVector(self.sizes().slice(n_batches, 2));
    if (layout == at::kSparseBsr || layout == at::kSparseBsc) {
      auto blocksize = at::sparse_csr::getBlockSize(self);
      sparse_dim[0] /= blocksize[0];
      sparse_dim[1] /= blocksize[1];
    }
    return sparse_dim;
  }();

  // batch_sizes_nonempty stores at least one, potentially fake, batch dimension.
  // rebatch_sizes_nonempty is equivalent to batch_sizes_nonempty.push_back(-1),
  // and is used to unflatten batch dimensions from a dimension of size
  // (batch_numel * dim_size,) for some dim_size.
  const auto batch_sizes_nonempty = at::DimVector(plain_indices.sizes().slice(0, n_batches_nonzero));
  auto rebatch_sizes_nonempty = at::DimVector(batch_sizes_nonempty);
  rebatch_sizes_nonempty.push_back(-1);
  const auto batch_numel_nonzero = std::accumulate(
      batch_sizes_nonempty.begin(),
      batch_sizes_nonempty.begin() + n_batches_nonzero,
      1,
      std::multiplies<int64_t>());

  // Equivalent to (arange(batch_numel_nonzero).mul_(nnz)).reshape(batch_sizes_nonempty).
  // We just compute it differently to use the `add` kernel in place of `mul` for better
  // performance.
  const auto batch_nnz_offset = [&]() -> Tensor {
    const auto wrapped_nnz = at::tensor({nnz}, compressed_indices.options());
    auto offset = wrapped_nnz
      .expand({batch_numel_nonzero})
      .cumsum(-1).sub_(wrapped_nnz)
      .reshape(batch_sizes_nonempty);
    return offset;
  }();
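  // For example, with hypothetical values batch_numel_nonzero = 3 and nnz = 4,
  // the expression above yields cumsum([4, 4, 4]) - 4 = [0, 4, 8],
  // i.e. the same result as arange(3) * 4.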

  // Step 1 for CSR/BSR inputs:
  // Convert a sparse compressed index representing batches of matrices of
  // shape (b, r, c) to a sparse compressed index that represents a single
  // matrix of shape (b * r, c).
  // The algorithm is identical for CSC/BSC inputs, with the batch dimensions
  // flattened in the "transposed" dimension.
  const auto compressed_indices_2d = [&]() -> Tensor {
    // Extract the offsets for all but the last entry along each row/col,
    // i.e. compressed_indices[..., :-1].
    const auto compressed_offsets = compressed_indices.slice(-1, 0, -1);
    // batch_offsets shifts each individual matrix's row/col offsets by the total
    // sum of nnz's of all the matrices with a smaller batch index.
    const auto batch_offsets = batch_nnz_offset
      .unsqueeze(-1).expand_as(compressed_offsets);
    // compressed_offsets + batch_offsets creates an offset vector for a 2d matrix
    // that is stored in a compressed sparse format.
    const auto compressed_offsets_2d = compressed_offsets.add(batch_offsets).reshape({-1});
    const auto offsets_len = compressed_offsets_2d.numel();
    auto res = at::empty({offsets_len + 1}, compressed_indices.options());
    res.slice(-1, 0, -1).copy_(compressed_offsets_2d);
    // By appending nnz * batch_numel_nonzero to (compressed_offsets + batch_offsets)
    // a compressed index of a 2d matrix is formed.
    res.slice(-1, -1).fill_(nnz * batch_numel_nonzero);
    return res;
  }();
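  // For example (a hypothetical batch of two 2 x 3 CSR matrices with
  // compressed_indices = [[0, 1, 2], [0, 0, 2]] and nnz = 2 each):
  // compressed_offsets = [[0, 1], [0, 0]], batch_offsets = [[0, 0], [2, 2]],
  // so compressed_offsets_2d = [0, 1, 2, 2] and the result is [0, 1, 2, 2, 4],
  // which is exactly the crow index of the two matrices stacked vertically
  // into a single (4, 3) matrix.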
  // More involved for compressed indices, but pretty easy for plain_indices and values:
  // just squash the batch dimensions.
  const auto plain_indices_2d = plain_indices.flatten(0, n_batches_nonzero);
  // NOTE: values are not 2d! They just represent values of a sparse compressed 2d matrix.
  const auto values_2d = values.flatten(0, n_batches_nonzero);

  const auto is_out_int32 = compressed_indices.scalar_type() == ScalarType::Int;

  // Step 2 & 3:
  //
  // Turn the compressed indices of the matrix of shape (b * r, c) into COO indices.
  //
  // Map these COO indices into the COO indices of a matrix of shape (r, b * c)
  // such that if A is a matrix of shape (b * r, c) and B is a matrix of shape
  // (r, b * c) such that
  // A[(k * r):(k * r + r), :] = B[:, (k * c):(k * c + c)] for all k in arange(b),
  // then A[i, j] = B[i', j'].
  // This is equivalent to finding indices that match values of matrices
  // tiled vertically to values of the same matrices tiled horizontally.

  // coo <-> sparse index conversions assume CSR/BSR inputs.
  // For CSC/BSC inputs these indices will appear "transposed".
  const auto is_transposed_indices = layout == at::kSparseCsc || layout == at::kSparseBsc;
  const auto coo_indices_2d_transposed = [&]() -> Tensor {
    auto coo_indices_2d = _convert_indices_from_csr_to_coo(
        compressed_indices_2d,
        plain_indices_2d,
        is_out_int32,
        /*transpose=*/true); // Flip rows/cols for convenience.
    // Convert COO indices of (b * r, c) to (r, b * c).
    // It is a map (i, j) -> {
    //    b = i // r
    //    i' = i % r
    //    j' = j + b * c
    //    return (i', j')
    // }
    // NOTE: we used transpose=true above!
    auto i = coo_indices_2d.select(0, 1);
    auto j = coo_indices_2d.select(0, 0);
    auto b = i.div(is_transposed_indices ? sparse_dim[1] : sparse_dim[0], "trunc");
    // Modify i, j in-place.
    i.fmod_(is_transposed_indices ? sparse_dim[1] : sparse_dim[0]);
    j.add_(b * (is_transposed_indices ? sparse_dim[0] : sparse_dim[1]));
    return coo_indices_2d;
  }();
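  // Continuing the hypothetical example above (r = 2, c = 3, b = 2) with
  // plain_indices_2d = [2, 0, 1, 2]: the CSR->COO conversion produces
  // rows [0, 1, 3, 3] and cols [2, 0, 1, 2], and the map yields
  // i' = [0, 1, 1, 1] and j' = [2, 0, 4, 5], i.e. the COO indices of the
  // same nonzeros in the horizontally tiled (2, 6) matrix B.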

  // Step 4:
  // Convert the COO indices to the CSC/BSC indices and form the output.
  // We need to sort the COO indices along the "transposed" dim to satisfy the
  // invariant of sorted plain indices.
  // Hash the COO indices by converting 2d indices to linear offsets with
  // more "weight" (aka stride) placed on the "transposed" dimension.
  const auto coo_indices_2d_transposed_hashed = at::sparse::flatten_indices(
      coo_indices_2d_transposed,
      is_transposed_indices ? at::DimVector({sparse_dim[0], sparse_dim[1] * batch_numel_nonzero})
                            : at::DimVector({sparse_dim[1], sparse_dim[0] * batch_numel_nonzero}));
  const auto hash_argsort = std::get<1>(coo_indices_2d_transposed_hashed.sort());
  const auto coo_indices_2d_transposed_sorted = coo_indices_2d_transposed.index_select(1, hash_argsort);
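  // With the hypothetical example above (is_transposed_indices == false), the
  // hash is j' * (sparse_dim[0] * batch_numel_nonzero) + i' = j' * 4 + i' =
  // [8, 1, 17, 21], so hash_argsort = [1, 0, 2, 3] reorders the entries so
  // that the new compressed (column) indices come out sorted.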

  const auto new_compressed_indices_coo_2d = coo_indices_2d_transposed_sorted.select(0, 0);
  const auto new_plain_indices_2d = coo_indices_2d_transposed_sorted.select(0, 1);
  const auto new_values_2d = values_2d.index_select(0, hash_argsort);

  auto new_compressed_indices = compressed_to_batched_compressed_indices(
      _convert_indices_from_coo_to_csr(
          new_compressed_indices_coo_2d,
          is_transposed_indices
            ? batch_numel_nonzero * sparse_dim[0]
            : batch_numel_nonzero * sparse_dim[1],
          is_out_int32),
      batch_numel_nonzero,
      is_out_int32)
    .unflatten(0, batch_sizes_nonempty);
  auto new_plain_indices = new_plain_indices_2d.unflatten(0, rebatch_sizes_nonempty);
  auto new_values = new_values_2d.unflatten(0, rebatch_sizes_nonempty);
  // Kill the fake batch dim if it was inserted.
  if (!n_batches) {
    new_compressed_indices.squeeze_(0);
    new_plain_indices.squeeze_(0);
    new_values.squeeze_(0);
  }

  return _sparse_compressed_tensor_unsafe(
      new_compressed_indices,
      new_plain_indices,
      new_values,
      self.sizes(),
      self.options().layout(flipped_layout));
}

Tensor sparse_compressed_to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseCsr;
  TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_csr: unexpected same input and output layout");
  _to_sparse_check_arguments("sparse_compressed_to_sparse_csr", self, layout_to, {}, dense_dim_opt);

  if (self.layout() == kSparseCsc) {
    return sparse_compressed_to_flipped(self, std::nullopt, "to_sparse_csr");
  }

  AT_ERROR("sparse_compressed_to_sparse_csr: expected SparseCsr or SparseCsc layout but got ", self.layout());
  return Tensor{};
}

Tensor sparse_compressed_to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseCsc;
  TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_csc: unexpected same input and output layout");
  _to_sparse_check_arguments("sparse_compressed_to_sparse_csc", self, layout_to, {}, dense_dim_opt);

  if (self.layout() == kSparseCsr) {
    return sparse_compressed_to_flipped(self, std::nullopt, "to_sparse_csc");
  }

  AT_ERROR("sparse_compressed_to_sparse_csc: expected SparseCsr or SparseCsc layout but got ", self.layout());
  return Tensor{};
}

Tensor coo_to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseCsr;
  _to_sparse_check_arguments("coo_to_sparse_csr", self, layout_to, {}, dense_dim_opt);

  auto coalesced_self = self.coalesce();
  auto row_indices = coalesced_self.indices()[0];
  bool out_int32 = (row_indices.scalar_type() == at::kInt);
  auto crow_indices = at::_convert_indices_from_coo_to_csr(
      row_indices, self.size(0), out_int32);
  return at::native::_sparse_csr_tensor_unsafe(
      crow_indices,
      coalesced_self.indices()[1].contiguous(),
      coalesced_self.values(),
      coalesced_self.sizes(),
      coalesced_self.scalar_type(),
      c10::kSparseCsr,
      coalesced_self.device());
}

Tensor coo_to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseCsc;
  _to_sparse_check_arguments("coo_to_sparse_csc", self, layout_to, {}, dense_dim_opt);

  auto transposed_csr = self.transpose(0, 1).to_sparse_csr(dense_dim_opt);
  return at::native::_sparse_csc_tensor_unsafe(
      transposed_csr.crow_indices(),
      transposed_csr.col_indices(),
      transposed_csr.values(),
      self.sizes(),
      transposed_csr.scalar_type(),
      c10::kSparseCsc,
      transposed_csr.device());
}

Tensor coo_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseBsr;
  _to_sparse_check_arguments("coo_to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);

  return self.to_sparse_csr(dense_dim_opt).to_sparse_bsr(blocksize);
}

Tensor coo_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseBsc;
  _to_sparse_check_arguments("coo_to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);

  return self.to_sparse_csc(dense_dim_opt).to_sparse_bsc(blocksize);
}

namespace {
template <typename input_t, typename output_t>
void convert_indices_from_coo_to_csr_cpu(
    const Tensor& result,
    const Tensor& input,
    const int64_t size) {
  int64_t numel = input.numel();
  const input_t* data_in = input.const_data_ptr<input_t>();
  output_t* data_out = result.data_ptr<output_t>();

  if (numel == 0) {
    result.zero_();
    return;
  }

  for (int64_t i = 0; i <= data_in[0]; i++)
    data_out[i] = static_cast<output_t>(0);

  at::parallel_for(
      0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
        input_t curr_value = data_in[start], next_value;
        for (const auto i : c10::irange(start, end)) {
          next_value = data_in[i + 1];
          for (; curr_value < next_value; curr_value++)
            data_out[curr_value + 1] = static_cast<output_t>(i + 1);
        }
      });
  for (int64_t i = data_in[numel - 1] + 1; i < size + 1; i++) {
    data_out[i] = static_cast<output_t>(numel);
  }
}
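
// For example (a hypothetical input): sorted COO row indices [0, 0, 2] with
// size = 4 produce the compressed index [0, 2, 2, 3, 3], i.e. rows 0..3
// contain 2, 0, 1 and 0 nonzeros respectively.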

template <typename input_t, typename output_t>
void convert_indices_from_csr_to_coo_cpu(
    const Tensor& indices,
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const bool transpose = false) {
  int64_t nrows = crow_indices.size(-1) - 1;
  int64_t nnz = col_indices.size(-1);
  if (nrows == 0 || nnz == 0) {
    indices.zero_(); // is this needed, given that indices has a zero-sized
                     // dimension when nrows or nnz is 0?
    return;
  }
  auto crow_indices_ = crow_indices.expect_contiguous();
  int64_t total_nnz = col_indices.numel();
  int64_t batch_ndim = crow_indices.dim() - 1;
  if (batch_ndim > 0) {
    auto batch_indices = indices.narrow(0, 0, batch_ndim);
    batch_indices.copy_(at::sparse::full_coo_indices(crow_indices.sizes().slice(0, batch_ndim), crow_indices.options())
                            .repeat_interleave(nnz, 1));
  }
  const input_t* crow_indices_data_in = crow_indices_->const_data_ptr<input_t>();
  TORCH_INTERNAL_ASSERT(indices.is_contiguous());
  auto row0 = indices.select(0, transpose ? batch_ndim + 1 : batch_ndim + 0);
  auto row1 = indices.select(0, transpose ? batch_ndim + 0 : batch_ndim + 1);
  output_t* data_out = row0.data_ptr<output_t>();
  auto col_indices_ = col_indices.expect_contiguous();
  row1.copy_(col_indices_->view({-1}));
  at::parallel_for(
      0, nrows * total_nnz / nnz, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
        for (const auto i_ : c10::irange(start, end)) {
          auto b = i_ / nrows;
          auto i = i_ % nrows;
          std::fill(
              &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i]],
              &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i + 1]],
              static_cast<output_t>(i));
        }
      });
}
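
// For example (a hypothetical non-batched input): crow_indices = [0, 2, 2, 3]
// and col_indices = [1, 3, 2] produce indices = [[0, 0, 2], [1, 3, 2]];
// with transpose == true the two rows are swapped.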
} // namespace

TORCH_IMPL_FUNC(_convert_indices_from_coo_to_csr_structured_cpu)
(const Tensor& input,
 const int64_t size,
 const bool out_int32,
 const Tensor& result) {
  if (out_int32) {
    AT_DISPATCH_INTEGRAL_TYPES(
        input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
          convert_indices_from_coo_to_csr_cpu<scalar_t, int32_t>(
              result, input, size);
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(
        input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
          convert_indices_from_coo_to_csr_cpu<scalar_t, int64_t>(
              result, input, size);
        });
  }
}

TORCH_IMPL_FUNC(_convert_indices_from_csr_to_coo_structured_cpu)
(const Tensor& crow_indices,
 const Tensor& col_indices,
 const bool out_int32,
 const bool transpose,
 const Tensor& result) {
  if (out_int32) {
    AT_DISPATCH_INTEGRAL_TYPES(
        crow_indices.scalar_type(), "convert_indices_from_csr_to_coo_cpu", [&] {
          convert_indices_from_csr_to_coo_cpu<scalar_t, int32_t>(
              result, crow_indices, col_indices, transpose);
        });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(
        crow_indices.scalar_type(), "convert_indices_from_csr_to_coo_cpu", [&] {
          convert_indices_from_csr_to_coo_cpu<scalar_t, int64_t>(
              result, crow_indices, col_indices, transpose);
        });
  }
}

/*
 * Based on
 * https://github.com/scipy/scipy/blob/8a64c938ddf1ae4c02a08d2c5e38daeb8d061d38/scipy/sparse/sparsetools/csr.h
 * Modified to ensure sorted BSR column indices.
 */
template <class index_t, class scalar_t, bool compressed_rows>
void _compressed_to_block_compressed_cpu_kernel(
    const index_t n_compressed, // Tensor size along compressed dimension
    const index_t n_plain, // Tensor size along plain dimension
    const index_t C, // Block size along compressed dimension
    const index_t P, // Block size along plain dimension
    const index_t D, // Number of elements in dense dimensions
    const index_t* input_compressed_indices,
    const index_t* input_plain_indices,
    const scalar_t* input_values,
    index_t* result_compressed_indices,
    index_t* result_plain_indices,
    scalar_t* result_values) {
  // Any block may be allocated, but only if at least one non-zero value lives
  // within it; all-zero blocks are not materialized.

  // Allocate pointers for all possible plain blocks plus 1
  std::vector<scalar_t*> blocks(n_plain / P + 1, nullptr);

  assert(n_compressed % C == 0);
  assert(n_plain % P == 0);

  // Number of blocks along compressed dim
  index_t n_bcompressed = n_compressed / C;
  // Number of blocks along plain dim
  index_t n_bplain = n_plain / P;

  // Number of elements per block
  index_t CPD = C * P * D;
  // Number of blocks overall
  index_t n_blks = 0;

  result_compressed_indices[0] = 0;

  // Iterate over blocks along compressed dim
  for (index_t block_c = 0; block_c < n_bcompressed; block_c++) {
    // Iterate over blocks along plain dim to locate non-zero blocks;
    // this guarantees sorted plain dim indices.
    for (index_t block_p = 0; block_p < n_bplain; block_p++) {
      for (index_t i = input_compressed_indices[C * block_c]; i < input_compressed_indices[C * (block_c + 1)]; i++) {
        index_t p = input_plain_indices[i]; // plain dim element index
        if (p / P == block_p) {
          blocks[block_p] = result_values + CPD * n_blks;
          result_plain_indices[n_blks] = block_p;
          n_blks++;
          break;
        }
      }
    }

    // Iterate over compressed dim within block
    for (index_t cb = 0; cb < C; cb++) {
      index_t c = C * block_c + cb; // compressed dim index
      for (index_t i = input_compressed_indices[c]; i < input_compressed_indices[c + 1]; i++) {
        index_t p = input_plain_indices[i]; // plain dim index

        // Block corresponding to plain dim index
        index_t block_p = p / P;
        // Plain dim index within block
        index_t pb = p % P;

        // Specific block entries should not be visited more than
        // once. The Scipy code does an addition here. Why?
        // A possible answer: the Scipy code supports an "uncoalesced CSR"
        // format that allows repeated plain dim indices, and
        // compressed and plain indices may be unsorted.
        std::copy(input_values + i * D, input_values + (i + 1) * D,
                  blocks[block_p] + (compressed_rows ? P * cb + pb : C * pb + cb) * D);
      }
    }

    // The Scipy code has
    /*
      for (I i = input_compressed_indices[C * block_c];
           i < input_compressed_indices[C * (block_c + 1)];
           i++) {
        blocks[input_plain_indices[i] / P] = 0;
      }
    */
    // but we don't need it because the modified code (see the block_p
    // loop above) does not need to evaluate `blocks[block_p] == 0`
    // as the original code did.
    result_compressed_indices[block_c + 1] = n_blks;
  }
}
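
// For example (a hypothetical CSR input, D = 1, compressed_rows == true):
// a 4 x 4 matrix with crow = [0, 1, 2, 3, 4], col = [0, 3, 1, 2] and
// values = [1, 2, 3, 4] converted with 2 x 2 blocks yields
// result_compressed_indices = [0, 2, 4], result_plain_indices = [0, 1, 0, 1]
// and the blocks [[1, 0], [0, 0]], [[0, 0], [0, 2]], [[0, 3], [0, 0]],
// [[0, 0], [4, 0]].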

/*
 * Based on
 * https://github.com/scipy/scipy/blob/8a64c938ddf1ae4c02a08d2c5e38daeb8d061d38/scipy/sparse/sparsetools/csr.h
 */
template <class index_t>
index_t compressed_count_blocks(
    const index_t n_compressed, // Tensor size along compressed dimension
    const index_t n_plain, // Tensor size along plain dimension
    const index_t C, // Block size along compressed dimension
    const index_t P, // Block size along plain dimension
    const index_t Ac[], // Compressed indices
    const index_t Ap[] // Plain indices
) {
  std::vector<index_t> mask(n_plain / P + 1, -1);
  index_t n_blks = 0;
  for (index_t c = 0; c < n_compressed; c++) {
    index_t bc = c / C;
    for (index_t i = Ac[c]; i < Ac[c + 1]; i++) {
      index_t bp = Ap[i] / P;
      if (mask[bp] != bc) {
        mask[bp] = bc;
        n_blks++;
      }
    }
  }
  return n_blks;
}
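
// For the hypothetical 4 x 4 example above (Ac = [0, 1, 2, 3, 4],
// Ap = [0, 3, 1, 2], C = P = 2), each of the four 2 x 2 blocks contains a
// nonzero, so compressed_count_blocks returns 4.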

template<Layout target_layout>
Tensor _compressed_to_block_compressed_cpu(const Tensor& self, IntArrayRef blocksize) {
  static_assert(target_layout == Layout::SparseBsr || target_layout == Layout::SparseBsc,
                "invalid layout template parameter for _compressed_to_block_compressed_cpu");

  auto input_values = self.values().contiguous();
  Tensor input_compressed_indices;
  Tensor input_plain_indices;
  std::tie(input_compressed_indices, input_plain_indices) = sparse_csr::getCompressedPlainIndices(self);
  input_compressed_indices = input_compressed_indices.contiguous();
  input_plain_indices = input_plain_indices.contiguous();

  // First we determine the number of blocks needed. For each given
  // block, if it contains a non-zero element we will allocate values
  // and indices for it.
  int64_t num_blocks = 0;
  auto compressed_dim = (target_layout == Layout::SparseBsr) ? self.size(0) : self.size(1);
  auto plain_dim = (target_layout == Layout::SparseBsr) ? self.size(1) : self.size(0);
  auto compressed_blocksize = (target_layout == Layout::SparseBsr) ? blocksize[0] : blocksize[1];
  auto plain_blocksize = (target_layout == Layout::SparseBsr) ? blocksize[1] : blocksize[0];

  AT_DISPATCH_INDEX_TYPES(
      input_compressed_indices.scalar_type(), "_compressed_to_block_compressed_cpu", [&] {
        num_blocks =
            compressed_count_blocks<index_t>(
                compressed_dim,
                plain_dim,
                compressed_blocksize,
                plain_blocksize,
                input_compressed_indices.data_ptr<index_t>(),
                input_plain_indices.data_ptr<index_t>());
      });
  DimVector dense_shape{input_values.sizes().slice(1, input_values.dim() - 1)};
  DimVector values_shape{num_blocks, blocksize[0], blocksize[1]};
  values_shape.append(dense_shape);

  Tensor result_values = input_values.new_zeros(values_shape);
  Tensor result_compressed_indices =
      input_compressed_indices.new_empty({compressed_dim / compressed_blocksize + 1});
  Tensor result_plain_indices = input_plain_indices.new_empty({num_blocks});

  // Next we copy over non-zero elements into the allocated blocks.
  auto n_dense = std::accumulate(
      dense_shape.begin(), dense_shape.end(), 1, std::multiplies<int64_t>());
  AT_DISPATCH_INDEX_TYPES(
      input_compressed_indices.scalar_type(), "_compressed_to_block_compressed_cpu", [&] {
        AT_DISPATCH_SPARSE_VALUE_TYPES(
            input_values.scalar_type(), "_compressed_to_block_compressed_cpu", [&] {
              _compressed_to_block_compressed_cpu_kernel<index_t, scalar_t, target_layout == Layout::SparseBsr>(
                  compressed_dim,
                  plain_dim,
                  compressed_blocksize,
                  plain_blocksize,
                  n_dense,
                  input_compressed_indices.data_ptr<index_t>(),
                  input_plain_indices.data_ptr<index_t>(),
                  input_values.data_ptr<scalar_t>(),
                  result_compressed_indices.data_ptr<index_t>(),
                  result_plain_indices.data_ptr<index_t>(),
                  result_values.data_ptr<scalar_t>());
            });
      });

  return at::_sparse_compressed_tensor_unsafe(
      result_compressed_indices,
      result_plain_indices,
      result_values,
      self.sizes(),
      self.options().layout(target_layout));
}

Tensor sparse_compressed_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseBsr;
  TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_bsr: unexpected same input and output layout");
  _to_sparse_check_arguments("sparse_compressed_to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);

  if (self.layout() == kSparseBsc) {
    return sparse_compressed_to_flipped(self, blocksize, "to_sparse_bsr");
  }
  if (self.layout() == kSparseCsr) {
    if (self.device() != kCPU) {
      TORCH_WARN("sparse_compressed_to_sparse_bsr executing on the CPU device, the performance may be sub-optimal");
    }
    return _compressed_to_block_compressed_cpu<kSparseBsr>(self.cpu(), blocksize).to(self.device());
  }
  if (self.layout() == kSparseCsc) {
    return self.to_sparse_csr(dense_dim_opt).to_sparse_bsr(blocksize);
  }

  AT_ERROR("sparse_compressed_to_sparse_bsr: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout());
  return Tensor{};
}

Tensor sparse_compressed_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseBsc;
  TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_bsc: unexpected same input and output layout");
  _to_sparse_check_arguments("sparse_compressed_to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);

  if (self.layout() == kSparseBsr) {
    return sparse_compressed_to_flipped(self, blocksize, "to_sparse_bsc");
  }
  if (self.layout() == kSparseCsc) {
    if (self.device() != kCPU) {
      TORCH_WARN("sparse_compressed_to_sparse_bsc executing on the CPU device, the performance may be sub-optimal");
    }
    return _compressed_to_block_compressed_cpu<kSparseBsc>(self.cpu(), blocksize).to(self.device());
  }
  if (self.layout() == kSparseCsr) {
    return self.to_sparse_csc(dense_dim_opt).to_sparse_bsc(blocksize);
  }

  AT_ERROR("sparse_compressed_to_sparse_bsc: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout());
  return Tensor{};
}

Tensor sparse_coo_to_sparse(const Tensor& self, const int64_t sparse_dim) {
  _to_sparse_check_arguments("sparse_coo_to_sparse", self, sparse_dim);

  AT_ERROR("sparse_coo_to_sparse: ", self.layout(), " to ", kSparse, " conversion not supported");
  return Tensor{};
}

Tensor sparse_compressed_to_sparse(const Tensor& self, const int64_t sparse_dim) {
  _to_sparse_check_arguments("sparse_compressed_to_sparse", self, sparse_dim);

  Layout layout = self.layout();
  auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(self);
  Tensor values;
  Tensor indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices,
                                                        false, (layout == kSparseCsc || layout == kSparseBsc));
  const auto batch_ndim = compressed_indices.dim() - 1;
  // Only CSR is trivially coalesced
  bool coalesced = layout == kSparseCsr || self.numel() == 0 || self._nnz() == 1;
  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "sparse_compressed_to_sparse",
    [&] { values = self.values().flatten(0, batch_ndim); },
    [&] {
      auto blocksize = DimVector(self.values().sizes().slice(batch_ndim + 1, 2));
      DimVector batch_blocksize;
      batch_blocksize.append(batch_ndim, 1);
      batch_blocksize.append(blocksize);
      const auto block_coo_indices = at::zeros({batch_ndim + 2, blocksize[0] * blocksize[1]}, indices.options());
      block_coo_indices.narrow(0, batch_ndim, 2).copy_(at::sparse::full_coo_indices(blocksize, indices.options()));
      indices = indices
        // Scale indices that identify blocks to element-wise coordinates that correspond
        // to the top-left corner of each block.
        .mul(at::tensor(batch_blocksize, indices.options()).unsqueeze_(1))
        // Now that we know the top-left block coordinates, we offset them with element-wise
        // coordinates in the block to get the result.
        // NOTE: indices is mapped from (dim, nnz) to (dim, nnz, 1),
        // and block_coo_indices is mapped from (dim, block_numel) to
        // (dim, 1, block_numel), so the result has shape
        // (dim, nnz, block_numel).
        .unsqueeze_(-1).add(block_coo_indices.unsqueeze_(1))
        // Squash the nnz and the block_numel dimensions
        // to produce a valid nnz dimension of a COO tensor.
        .flatten(-2, -1);

      values = self.values().flatten(0, batch_ndim + 2);

      // BSR tensors whose blocks do not span several rows produce coalesced results.
      coalesced |= (layout == kSparseBsr && blocksize[0] == 1 && batch_ndim == 0);
    });
  return at::native::_sparse_coo_tensor_unsafe(indices, values, self.sizes())._coalesced_(coalesced);
}
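
// For example (a hypothetical BSR input with blocksize (2, 2) and no batch
// dims): a block whose (block-row, block-col) COO index is (1, 0) is first
// scaled to the element-wise top-left corner (2, 0) and then offset by the
// intra-block coordinates [[0, 0, 1, 1], [0, 1, 0, 1]], producing the COO
// element indices (2, 0), (2, 1), (3, 0) and (3, 1).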

Tensor sparse_compressed_to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = layout.value_or(kSparse);
  TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse: unexpected same input and output layout");
  _to_sparse_check_arguments("sparse_compressed_to_sparse", self, layout_to, blocksize, dense_dim_opt);

  auto blocksize_ = blocksize.value_or((self.layout() == kSparseBsr || self.layout() == kSparseBsc) ? at::sparse_csr::getBlockSize(self) : at::DimVector({1, 1}));
  switch (layout_to) {
    case kStrided:
      return sparse_compressed_to_dense(self, /*dtype=*/std::nullopt, /*masked_grad=*/std::nullopt);
    case kSparse:
      return sparse_compressed_to_sparse(self, 2);
    case kSparseCsr:
      return sparse_compressed_to_sparse_csr(self, dense_dim_opt);
    case kSparseCsc:
      return sparse_compressed_to_sparse_csc(self, dense_dim_opt);
    case kSparseBsr:
      return sparse_compressed_to_sparse_bsr(self, blocksize_, dense_dim_opt);
    case kSparseBsc:
      return sparse_compressed_to_sparse_bsc(self, blocksize_, dense_dim_opt);
    default:
      break;
  }

  AT_ERROR("sparse_compressed_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported");
  return Tensor{};
}

Tensor sparse_coo_to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = layout.value_or(kSparse);
  TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_coo_to_sparse: unexpected same input and output layout");
  _to_sparse_check_arguments("sparse_coo_to_sparse", self, layout_to, blocksize, dense_dim_opt);

  switch (layout_to) {
    case kStrided:
      return self.to_dense(std::nullopt, std::nullopt);
    case kSparseCsr:
      return self.to_sparse_csr(dense_dim_opt);
    case kSparseCsc:
      return self.to_sparse_csc(dense_dim_opt);
    case kSparseBsr:
      return self.to_sparse_bsr(*blocksize, dense_dim_opt);
    case kSparseBsc:
      return self.to_sparse_bsc(*blocksize, dense_dim_opt);
    default:
      break;
  }

  AT_ERROR("sparse_coo_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported");
  return Tensor{};
}

Tensor to_sparse(const Tensor& self, const int64_t sparse_dim) {
  auto layout_to = kSparse;
  if (self.layout() == layout_to) {
    _to_sparse_check_arguments("to_sparse", self, sparse_dim);
    return self;
  }
  return self._to_sparse(sparse_dim);
}

Tensor to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = layout.value_or(kSparse);
  if (self.layout() == layout_to) {
    _to_sparse_check_arguments("to_sparse", self, layout, blocksize, dense_dim_opt);
    return self;
  }
  return self._to_sparse(layout, blocksize, dense_dim_opt);
}

Tensor to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseCsr;
  if (self.layout() == layout_to) {
    _to_sparse_check_arguments("to_sparse_csr", self, layout_to, {}, dense_dim_opt);
    return self;
  }
  return self._to_sparse_csr(dense_dim_opt);
}

Tensor to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseCsc;
  if (self.layout() == layout_to) {
    _to_sparse_check_arguments("to_sparse_csc", self, layout_to, {}, dense_dim_opt);
    return self;
  }
  return self._to_sparse_csc(dense_dim_opt);
}

Tensor to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseBsr;
  if (self.layout() == layout_to) {
    _to_sparse_check_arguments("to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);
    return self;
  }
  return self._to_sparse_bsr(blocksize, dense_dim_opt);
}

Tensor to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
  auto layout_to = kSparseBsc;
  if (self.layout() == layout_to) {
    _to_sparse_check_arguments("to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);
    return self;
  }
  return self._to_sparse_bsc(blocksize, dense_dim_opt);
}
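
// A minimal usage sketch of the conversion entry points above (illustrative
// only; the tensor names and shapes are arbitrary):
//
//   auto dense = at::randn({4, 4});
//   auto csr = dense.to_sparse_csr();       // Strided -> CSR
//   auto bsr = csr.to_sparse_bsr({2, 2});   // CSR -> BSR (CPU kernel above)
//   auto coo = bsr.to_sparse();             // BSR -> COO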

// Sparse layout conversions End

Tensor to_meta(const Tensor& tensor) {
  auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(),
                                                   /*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()),
                                                   /*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
  // Needs to handle wrapped numbers, so dtype promotion works properly.
  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
    out.unsafeGetTensorImpl()->set_wrapped_number(true);
  }
  return out;
}

std::optional<Tensor> to_meta(const std::optional<Tensor>& tensor) {
  if (tensor.has_value()) {
    return to_meta(*tensor);
  }
  return std::nullopt;
}

std::vector<Tensor> to_meta(at::ITensorListRef t_list) {
  std::vector<Tensor> outs;
  outs.reserve(t_list.size());
  for (const auto& tensor : t_list) {
    outs.push_back(to_meta(tensor));
  }
  return outs;
}
} // namespace at::native