// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorConversions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/ATen.h>
3 #include <ATen/core/Tensor.h>
4 #include <optional>
5 #include <ATen/quantized/Quantizer.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/Parallel.h>
8 #include <ATen/TensorOperators.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/_autocast_to_full_precision_native.h>
15 #include <ATen/ops/_autocast_to_reduced_precision_native.h>
16 #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
17 #include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
18 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
19 #include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
20 #include <ATen/ops/_sparse_bsc_tensor_unsafe_native.h>
21 #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
22 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
23 #include <ATen/ops/_sparse_coo_tensor_with_dims_native.h>
24 #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
25 #include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
26 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
27 #include <ATen/ops/_to_copy.h>
28 #include <ATen/ops/_to_copy_native.h>
29 #include <ATen/ops/_to_cpu_native.h>
30 #include <ATen/ops/_to_dense_native.h>
31 #include <ATen/ops/_to_sparse_bsc_native.h>
32 #include <ATen/ops/_to_sparse_bsr_native.h>
33 #include <ATen/ops/_to_sparse_csc_native.h>
34 #include <ATen/ops/_to_sparse_csr_native.h>
35 #include <ATen/ops/_to_sparse_native.h>
36 #include <ATen/ops/arange_native.h>
37 #include <ATen/ops/empty.h>
38 #include <ATen/ops/empty_like.h>
39 #include <ATen/ops/empty_quantized.h>
40 #include <ATen/ops/empty_strided.h>
41 #include <ATen/ops/empty_strided_native.h>
42 #include <ATen/ops/to_dense_backward_native.h>
43 #include <ATen/ops/to_dense_native.h>
44 #include <ATen/ops/to_mkldnn_backward_native.h>
45 #include <ATen/ops/to_native.h>
46 #include <ATen/ops/to_sparse_bsc_native.h>
47 #include <ATen/ops/to_sparse_bsr_native.h>
48 #include <ATen/ops/to_sparse_csc_native.h>
49 #include <ATen/ops/to_sparse_csr_native.h>
50 #include <ATen/ops/to_sparse_native.h>
51 #include <ATen/ops/view_native.h>
52 #include <ATen/ops/zeros.h>
53 #endif
54 
55 #include <ATen/SparseCsrTensorUtils.h>
56 #include <ATen/core/ATen_fwd.h>
57 #include <ATen/native/IndexingUtils.h>
58 #include <ATen/native/NonSymbolicBC.h>
59 #include <ATen/native/SparseTensorUtils.h>
60 #include <ATen/native/TensorConversions.h>
61 #include <c10/core/impl/DeviceGuardImplInterface.h>
62 #include <algorithm>
63 #include <numeric>
64 
65 namespace at::native {
66 
67 namespace {
68 // dense_to_sparse_{csr,bsr,csc,bsc} common helpers
69 
70 // Preparation for the N-D dense -> sparse compressed conversion.
71 // The N-D input is converted to 3-D (single batch dim) where we check that the
72 // product of batch dims is nonzero and for each batch the sparse matrix
73 // contained within has the same number of non-zero elements.
74 // The batches are joined along the compressed axis. The generation of indices
75 // for this matrix can be performed in a single step followed by a single step
76 // conversion to restore the batch dimension.
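//
// As a rough illustration (assuming a to_sparse_csr call with dense_dim=0):
// values/mask of shape (b0, b1, r, c) are first flattened to (b0*b1, r, c),
// and the row-compressed branch below then flattens the batch and row dims
// together, yielding shape (b0*b1*r, c).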
77 void dense_to_sparse_compressed_prepare_check_mask_values_batched(
78     const Layout& target_layout,
79     Tensor& values,
80     Tensor& mask,
81     const int64_t& n_batch_dim) {
82   if (n_batch_dim > 1) {
83     // For inputs with more than 1 batch dim we flatten them out.
84     // Input shape (b0, b1 ..., bn, r, c) -> (b0 * b1 * ... * bn, r ,c)
85     values = values.flatten(0, n_batch_dim - 1);
86     mask = mask.flatten(0, n_batch_dim - 1);
87   }
88 
89   // For an informative message, form the name of the function
90   // to_sparse_{csr,csc,bsr,bsc}.
91   TORCH_CHECK(
92       mask.size(0) > 0,
93       "to_sparse_",
94       // We want the message to match the function name so generate the
95       // lowercase acronym for the layout
96       sparse_csr::layoutToString(target_layout, false, true),
97       ": Expected product of batch dimensions to be non-zero.");
98 
99   // Compute the number of non-zero elements in the first batch, expand to full
100   // size
101   auto nse_per_batch = mask.select(0, 0).sum().expand(mask.size(0));
102   TORCH_CHECK(
103       mask.sum({-2, -1}).equal(nse_per_batch),
104       "Expect the same number of specified elements per batch.");
105 
106   // We need to join batches into a matrix increasing the length of the
107   // compressed axis. This allows us to create indices for a compressed matrix
108   // and de-batch them later (two kernels). Otherwise we would have to create
109   // indices for each batch individually requiring n_batch kernels. For csr/bsr,
110   // we already have the batch dim adjacent to the compressed axis and can
111   // flatten them together. For csc/bsc, we need to transpose first.
112   // For BSR/CSR (b, r, c) -> (b*r, c)
113   // For BSC/CSC (b, c, r) -> (r, b*c)
114   AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
115       target_layout,
116       "dense_to_sparse_compressed",
117       [&]() {
118         values = values.flatten(0, 1);
119         mask = mask.flatten(0, 1);
120       },
121       [&]() {
122         values = values.transpose(0, 1).flatten(1, 2);
123         mask = mask.transpose(0, 1).flatten(1, 2);
124       });
125 }
126 
127 // This function unfolds the compressed indices of a compressed sparse matrix
128 // into a batched compressed sparse tensor.
129 // This is analogous to an unflatten-like operation:
130 // unflatten(0, {b, r}) for csr/bsr with input shape (r*b, c)
131 //          (output shape (b, r, c))
132 // unflatten(1, {b, c}).transpose(0,1) for csc/bsc with input shape (r, c*b)
133 //          (output shape (r, b, c) unflatten, (b, r, c) unflatten + transpose)
134 // This only operates on the compressed indices as the plain indices and values
135 // can be manipulated as described above without special handling.
136 // It is a prerequisite for the conversion that the sparsity pattern is sane for
137 // the batched shape. That is, each batch has the same number of nonzero
138 // elements.
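//
// As a rough illustration, assuming n_batch = 2 and
// compressed_in = [0, 2, 4, 6, 8] (2 compressed rows and 4 nnz per batch):
//   n_compressed_per_batch = (5 - 1) / 2 = 2
//   trailing_slice = [[2, 4], [6, 8]]
//   offsets        = [[0], [4]]
//   batched_out    = [[0, 2, 4], [0, 2, 4]]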
139 Tensor compressed_to_batched_compressed_indices(
140     const Tensor& compressed_in,
141     const int64_t& n_batch,
142     bool out_int32) {
143   auto n_compressed_per_batch = (compressed_in.size(0) - 1) / n_batch;
144   ScalarType out_type = out_int32 ? ScalarType::Int : ScalarType::Long;
145   auto batched_out = at::zeros(
146       {n_batch, n_compressed_per_batch + 1},
147       compressed_in.options().dtype(out_type));
148 
149   // If the compressed dimension has length zero, each batch holds a single
150   // zero element, so the zero-initialized output above is already the result.
151   if (n_compressed_per_batch > 0) {
152     // Slice the compressed indices ignoring the leading 0 element and reshape
153     // to n-batch rows
154     auto trailing_slice =
155         compressed_in.slice(0, 1, std::nullopt, 1).reshape({n_batch, -1});
156     // Slice the compressed indices again selecting the elements corresponding
157     // to the batch boundary. The values here will be increasing multiples of
158     // nnz per batch. Reshape to n-batch rows (1 col) for broadcasting.
159     // This is equivalent to arange(n_batch) * nnz_per_batch with the same
160     // reshape
161     auto offsets = compressed_in.slice(0, 0, -1, n_compressed_per_batch)
162                        .reshape({n_batch, -1});
163     // Subtracting the offsets from each row of the reshaped compressed indices
164     // gives us the compressed indices within the batch. The leading element of
165     // each row is not computed as it is always zero.  We copy into the view on
166     // the output buffer.
167     batched_out.narrow(-1, 1, n_compressed_per_batch)
168         .copy_(trailing_slice - offsets);
169   }
170   return batched_out;
171 }
172 
173 // After generating the member tensors for the sparse compressed matrix, if the
174 // target shape is N-D we must re-form the batch dimensions.
175 // A single kernel is used to restore one batch dimension in the compressed
176 // indices. From there the full batch shape is restored by reshape. No special
177 // handling is needed for restoring the batch dimensions of the values or
178 // plain_indices; it can be done with reshape/unflatten.
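//
// As a rough illustration, assuming CSR with full_sizes = (2, 3, r, c) and
// n_batch_dim = 2, so n_batch = 6, with nnz denoting the (equal) number of
// specified elements per batch:
//   compressed_indices: (6*r + 1,)   -> (6, r + 1) -> (2, 3, r + 1)
//   plain_indices:      (6*nnz,)     -> (2, 3, nnz)
//   values:             (6*nnz, ...) -> (2, 3, nnz, ...)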
179 void reshape_2d_sparse_compressed_members_to_nd_batched(
180     const IntArrayRef full_sizes,
181     const int64_t& n_batch_dim,
182     Tensor& compressed_indices,
183     Tensor& plain_indices,
184     Tensor& values) {
185   auto batch_shape = full_sizes.slice(0, n_batch_dim);
186   auto n_batch = std::accumulate(
187       batch_shape.begin(), batch_shape.end(), 1, std::multiplies<int64_t>());
188   // NOTE: using this conversion requires the nnz per batch is the same for all
189   // batches that will be formed. We ensured this was the case on the way in so
190   // it is safe to use this conversion.
191   compressed_indices = compressed_to_batched_compressed_indices(
192       compressed_indices, n_batch, /*out_int32*/ false);
193 
194   // We can infer the last dim of the reshape targets; it will be nnz or
195   // nrow/ncol+1 depending on the layout and the member tensor targeted.
196   auto batchsize_infer_last = DimVector(batch_shape);
197   batchsize_infer_last.push_back(-1);
198 
199   // -1 will be nnz per batch
200   plain_indices = plain_indices.reshape(batchsize_infer_last);
201   // -1 will be ncols (bsc,csc) or nrows (bsr,csr) + 1
202   compressed_indices = compressed_indices.reshape(batchsize_infer_last);
203   // -1 will be nnz (per batch).
204   // Note: use unflatten rather than reshape as it works
205   // for both blocked and unblocked layouts; reshape works for unblocked
206   // layouts only.
207   values = values.unflatten(0, batchsize_infer_last);
208 }
209 } // namespace
210 
211 // Take a Device that may not have device_index set (i.e., having it as -1
212 // representing the current device) and return the corresponding Device
213 // according to the actual device at the time of this function call.  No-op
214 // if the device_index is set.
215 static inline Device ensure_has_index(Device device) {
216   if (device.is_cpu() || device.has_index()) {
217     return device;
218   }
219   const c10::impl::DeviceGuardImplInterface* impl = c10::impl::getDeviceGuardImpl(device.type());
220   return impl->getDevice();
221 }
222 
223 static inline std::optional<Device> ensure_has_index(std::optional<Device> device) {
224   if (!device.has_value()) {
225     return std::nullopt;
226   }
227   return ensure_has_index(device.value());
228 }
229 
230 Tensor _to_copy(
231     const Tensor& self,
232     std::optional<ScalarType> dtype,
233     std::optional<Layout> layout,
234     std::optional<Device> device,
235     std::optional<bool> pin_memory,
236     bool non_blocking,
237     std::optional<c10::MemoryFormat> optional_memory_format) {
238   TORCH_CHECK(!layout.has_value() || self.layout() == layout.value(),
239            "to(options) doesn't support converting to a different layout, "
240            "but got self.layout being ", self.layout(),
241            " and options.layout set as ", layout.value());
242   auto options = TensorOptions()
243     .dtype(dtype)
244     .layout(layout)
245     .device(device)
246     .pinned_memory(pin_memory);
247 
248   if (options.has_device()) {
249     options = options.device(ensure_has_index(options.device()));
250   }
251   // memory_format is handled separately due to MemoryFormat::Preserve logic
252   options = self.options().merge_in(options).memory_format(std::nullopt);
253   auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
254 
255   // TODO: Use the dispatcher for this.
256   // Currently there are unenumerated extensibility issues preventing this.
257   if (self.layout() == kSparse) {
258       TORCH_CHECK(
259           memory_format == MemoryFormat::Preserve,
260           "to(options): COO only supports memory format Preserve, but got ", memory_format,
261           " instead.");
262     if (options.device().is_meta()) {
263         return zeros_like(self, options);
264     }
265     auto indices = self._indices();
266     const auto new_indices = at::native::to(
267         indices,
268         indices.scalar_type(),
269         c10::kStrided,
270         device,
271         pin_memory,
272         non_blocking,
273         true, // force copy since we are in _to_copy
274         memory_format);
275     const auto new_values = at::native::to(
276         self._values(),
277         dtype,
278         c10::kStrided,
279         device,
280         pin_memory,
281         non_blocking,
282         true, // force copy since we are in _to_copy
283         memory_format);
284 
285     return at::_sparse_coo_tensor_unsafe(
286         new_indices,
287         new_values,
288         self.sizes(),
289         options, self.is_coalesced());
290   } else if (at::sparse_csr::is_sparse_compressed(self)) {
291       TORCH_CHECK(
292           memory_format == MemoryFormat::Preserve,
293           "to(options): ", at::sparse_csr::layoutToString(self.layout()),
294           " only supports memory format Preserve, but got ", memory_format,
295           " instead.");
296 
297       if (options.device().is_meta()) {
298         return zeros_like(self, options);
299       }
300 
301       auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(self);
302 
303       const auto new_values = at::native::to(
304           self.values(),
305           dtype,
306           c10::kStrided,
307           device,
308           pin_memory,
309           non_blocking,
310           true, // force copy since we are in _to_copy
311           memory_format);
312 
313       const auto new_compressed_indices = at::native::to(
314           compressed_indices,
315           compressed_indices.scalar_type(),
316           c10::kStrided,
317           device,
318           pin_memory,
319           non_blocking,
320           true, // force copy since we are in _to_copy
321           memory_format);
322 
323       const auto new_plain_indices = at::native::to(
324           plain_indices,
325           plain_indices.scalar_type(),
326           c10::kStrided,
327           device,
328           pin_memory,
329           non_blocking,
330           true, // force copy since we are in _to_copy
331           memory_format);
332 
333     return at::_sparse_compressed_tensor_unsafe(
334         new_compressed_indices,
335         new_plain_indices,
336         new_values,
337         self.sizes(),
338         options);
339   }
340 
341   bool pin_out = (non_blocking && (self.is_cuda() || self.is_privateuseone())
342                   && options.device().is_cpu() && (options.layout() == c10::kStrided));
343 
344   if (memory_format == MemoryFormat::Preserve) {
345     if (options.device().supports_as_strided()) {
346       if (self.is_non_overlapping_and_dense()) {
347         Tensor r;
348         if (self.is_quantized()) {
349           r = at::empty_quantized(self.sizes(), self, options);
350           at::QuantizerPtr quantizer = r.quantizer();
351           r.copy_(self, non_blocking);
352           set_quantizer_(r, quantizer);
353         } else {
354           r = at::empty_strided(
355               self.sizes(),
356               self.strides(),
357               options.pinned_memory(pin_out));
358           r.copy_(self, non_blocking);
359         }
360         return r;
361       } else if (!self.is_quantized() && self.layout() == kStrided) {
362           Tensor r;
363           auto strides = infer_dense_strides(self.sizes(), self.strides());
364           r = at::empty_strided(
365               self.sizes(),
366               strides,
367               options.pinned_memory(pin_out));
368           r.copy_(self, non_blocking);
369           return r;
370       } else {
371         memory_format = self.suggest_memory_format();
372       }
373     } else {
374       memory_format = self.suggest_memory_format();
375     }
376   }
377   // See Note [Explicit nullopt MemoryFormat argument]
378   // TODO: empty_quantized does not work here. It raises an exception in CheckMemoryFormat.h prior to
379   // empty_affine_quantized/_empty_per_channel_affine_quantized calls
380   // at::empty also does not work here because there is no proper at::empty support for quantized tensors
381   // as it would return a quantized tensor with an UnknownQuantizer
382   auto r = self.is_quantized() ? at::empty_like(self, memory_format)
383                                : at::empty_symint(self.sym_sizes(),
384                                  options.memory_format(memory_format).pinned_memory(pin_out), std::nullopt);
385   r.copy_(self, non_blocking);
386   return r;
387 }
388 
389 template <typename T>
390 static inline bool is_null_or_equal_to(const std::optional<T>& test, const T& value) {
391   if (!test.has_value()) {
392     return true;
393   }
394   return test.value() == value;
395 }
396 
397 // NOTE: static runtime's to_maybe_copy_out relies on details of this
398 // check; if you change how it works, please update static runtime as
399 // well.
400 bool to_will_alias(
401     const Tensor& self,
402     std::optional<ScalarType> dtype,
403     std::optional<Layout> layout,
404     std::optional<Device> device,
405     bool copy,
406     std::optional<c10::MemoryFormat> optional_memory_format) {
407   auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
408 
409   return is_null_or_equal_to(dtype, self.dtype().toScalarType()) &&
410     is_null_or_equal_to(layout, self.layout()) &&
411     is_null_or_equal_to(device, self.device()) &&
412     !copy &&
413     (memory_format == MemoryFormat::Preserve ||
414      self.suggest_memory_format() == memory_format);
415 }
416 
417 static inline Tensor to_impl(
418     const Tensor& self,
419     std::optional<ScalarType> dtype,
420     std::optional<Layout> layout,
421     std::optional<Device> device,
422     std::optional<bool> pin_memory,
423     bool non_blocking,
424     bool copy,
425     std::optional<c10::MemoryFormat> optional_memory_format) {
426 
427   // fast path
428   if (to_will_alias(self, dtype, layout, device, copy, optional_memory_format)) {
429     return self;
430   }
431   return at::_to_copy(
432       self, dtype, layout, device, pin_memory, non_blocking, optional_memory_format);
433 }
434 
435 // If the input tensor is fp32, cast it to the reduced precision dtype for its device (fp16/bf16); otherwise leave it alone.
436 // (this is intended to be used internally by the JIT autocast implementation)
437 Tensor _autocast_to_reduced_precision(const Tensor& self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) {
438   if (self.dtype() == at::ScalarType::Float &&
439       ((self.device().is_cuda() && cuda_enabled) ||
440       (self.device().is_cpu() && cpu_enabled))
441       ) {
442     at::ScalarType target = at::ScalarType::Undefined;
443     if (self.device().is_cuda()) {
444       target = cuda_dtype;
445     } else if (self.device().is_cpu()) {
446       target = cpu_dtype;
447     }
448 
449     TORCH_INTERNAL_ASSERT(target != at::ScalarType::Undefined, "_autocast_to_reduced_precision requires legit ScalarType argument for given device");
450 
451     return to_impl(
452         self, target, std::nullopt, std::nullopt, std::nullopt, false, false, std::nullopt);
453   } else {
454     return self;
455   }
456 }
457 
458 // If the input tensor is fp16 or bf16, cast it to fp32; otherwise leave it alone.
459 // (this is intended to be used internally by the JIT autocast implementation)
460 Tensor _autocast_to_full_precision(const Tensor& self, bool cuda_enabled, bool cpu_enabled) {
461   if ((self.dtype() == at::ScalarType::Half || self.dtype() == at::ScalarType::BFloat16) &&
462       ((self.device().is_cuda() && cuda_enabled) ||
463       (self.device().is_cpu() && cpu_enabled))
464       ) {
465     return to_impl(
466         self, at::ScalarType::Float, std::nullopt, std::nullopt, std::nullopt, false, false, std::nullopt);
467   } else {
468     return self;
469   }
470 }
471 
472 Tensor to(
473   const Tensor& self,
474     std::optional<ScalarType> dtype,
475     std::optional<Layout> layout,
476     std::optional<Device> device,
477     std::optional<bool> pin_memory,
478   bool non_blocking,
479   bool copy,
480   std::optional<c10::MemoryFormat> optional_memory_format
481 ) {
482   return to_impl(
483       self,
484       dtype,
485       layout,
486       ensure_has_index(device),
487       pin_memory,
488       non_blocking,
489       copy,
490       optional_memory_format);
491 }
492 
493 Tensor to(const Tensor& self, Device device, ScalarType dtype, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
494   return to_impl(
495       self,
496       dtype,
497       std::nullopt,
498       ensure_has_index(device),
499       std::nullopt,
500       non_blocking,
501       copy,
502       optional_memory_format);
503 }
504 
505 Tensor to(const Tensor& self, ScalarType dtype, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
506   return to_impl(
507       self,
508       dtype,
509       std::nullopt,
510       std::nullopt,
511       std::nullopt,
512       non_blocking,
513       copy,
514       optional_memory_format);
515 }
516 
517 Tensor to(const Tensor& self, const Tensor& other, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
518   auto options = other.options();
519   return to_impl(
520       self,
521       options.dtype().toScalarType(),
522       options.layout(),
523       options.device(),
524       options.pinned_memory(),
525       non_blocking,
526       copy,
527       optional_memory_format);
528 }
529 
530 // This op is important primarily for lazy / graph-based backends.
531 // While this vanilla implementation loops through each tensor and independently converts it to cpu,
532 // a lazy backend like XLA might need to sync updates across tensors.
533 std::vector<Tensor> _to_cpu(TensorList tensors) {
534     std::vector<Tensor> cpu_tensors;
535     for (const auto& t : tensors) {
536         cpu_tensors.push_back(t.cpu());
537     }
538     return cpu_tensors;
539 }
540 
541 Tensor to_dense_backward(const Tensor& grad, const Tensor& input_, std::optional<bool> masked_grad_) {
542   /*
543     For historical reasons, to_dense backward implements masked
544     semantics for sparse tensors, that is, gradients with respect to
545     unspecified elements are ignored.  The masked_grad kw argument of
546     to_dense is introduced to allow to_dense to be used in the
547     non-masked semantics context. However, for BC reasons, the default
548     value of the masked_grad kw argument is set to True for now.
549     Eventually, we should eliminate the masked_grad kw argument and
550     let to_dense backward behave according to non-masked
551     semantics. Masked semantics of tensors is implemented in the
552     framework of masked tensors.
553   */
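  // As a rough illustration for the kSparse case below: with masked_grad=true
  // the result is grad.sparse_mask(input_.coalesce()), i.e. only gradients at
  // the input's specified elements survive; with masked_grad=false the full
  // gradient is kept and merely converted via grad.to_sparse(...).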
554   const auto input_layout = input_.layout();
555   const bool masked_grad = masked_grad_.value_or(true);
556   switch (input_layout) {
557     case kStrided:
558       // TODO: return grad as it is
559       return grad.to_dense(input_.scalar_type(), masked_grad_);
560     case kSparse:
561       // Autograd operates on the coalesced assumption, i.e. no duplicate values.
562       if (masked_grad) {
563         return grad.sparse_mask(input_.coalesce());
564       } else {
565         // TODO: return grad as it is
566         return grad.to_sparse(input_.sparse_dim());
567       }
568     case kSparseCsr:
569     case kSparseCsc:
570       // TODO: add efficient CSR/CSC support for sparse_mask
571       if (masked_grad) {
572         return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout);
573       } else {
574         // TODO: return grad as it is
575         return grad.to_sparse(input_layout, /*blocksize=*/std::nullopt, /*dense_dim=*/input_.dense_dim());
576       }
577     case kSparseBsr:
578     case kSparseBsc: {
579       // TODO: add efficient BSR/BSC support for sparse_mask
580       const auto blocksize = at::sparse_csr::getBlockSize(input_);
581       if (masked_grad) {
582         return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout, blocksize);
583       } else {
584         // TODO: return grad as it is
585         return grad.to_sparse(input_layout, blocksize, input_.dense_dim());
586       }
587     }
588     case kMkldnn:
589       return grad.to_mkldnn(input_.scalar_type());
590     default:
591       AT_ERROR("to_dense_backward: Unsupported input layout: ", input_layout);
592       return Tensor{};
593   }
594 }
595 
596 Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) {
597   AT_ASSERT(input_.layout() == c10::kStrided);
598   return grad.to_dense(input_.scalar_type());
599 }
600 
601 Tensor to_dense(const Tensor& tensor, std::optional<c10::ScalarType> dtype, std::optional<bool> masked_grad) {
602   if (tensor.layout() == c10::kSparse) {
603     return tensor._to_dense(dtype, masked_grad);
604   }
605   if (tensor.layout() == c10::kSparseCsr ||
606       tensor.layout() == c10::kSparseCsc ||
607       tensor.layout() == c10::kSparseBsr ||
608       tensor.layout() == c10::kSparseBsc) {
609     return tensor._to_dense(dtype, masked_grad);
610   }
611   if (tensor.layout() == c10::kMkldnn) {
612     return tensor._to_dense(dtype, masked_grad);
613   }
614   TORCH_CHECK(
615       tensor.layout() == c10::kStrided,
616       "to_dense does not support layout ",
617       tensor.layout());
618   if (dtype) {
619     return tensor.to(*dtype);
620   }
621   return tensor;
622 }
623 
624 Tensor sparse_to_dense(const Tensor& self, std::optional<ScalarType> dtype, std::optional<bool> masked) {
625   TORCH_CHECK(
626       !dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
627   Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
628   return dst.add_(self);
629 }
630 
631 Tensor sparse_compressed_to_dense(
632     const Tensor& self,
633     std::optional<ScalarType> dtype,
634     std::optional<bool> masked_grad) {
635   TORCH_CHECK(
636       !dtype.has_value(),
637       "dtype argument is not supported by sparse_csr_to_dense");
638 
639   if (self.numel() == 0) {
640     return at::zeros(self.sizes(), self.options().layout(kStrided));
641   }
642 
643   auto batch_ndim = sparse_csr::numBatchDimensions(self);
644 
645   auto compressed_rows = self.layout() == kSparseCsr || self.layout() == kSparseBsr;
646   auto block_sparse = self.layout() == kSparseBsr || self.layout() == kSparseBsc;
647 
648   auto [compressed_indices, plain_indices] =
649       sparse_csr::getCompressedPlainIndices(self);
650 
651   auto values = self.values();
652   Tensor dense = at::zeros(self.sizes(), self.options().layout(kStrided));
653 
654   if (batch_ndim == 0) {
655     // Pad the shape so we can treat non-batched like batched; we will
656     // squeeze out the phantom batch dim at the end.
657     compressed_indices.unsqueeze_(0);
658     plain_indices.unsqueeze_(0);
659     values.unsqueeze_(0);
660     dense.unsqueeze_(0);
661   }
662   if (batch_ndim > 1) {
663     // Flatten batch dims
664     compressed_indices = compressed_indices.flatten(0, batch_ndim - 1);
665     plain_indices = plain_indices.flatten(0, batch_ndim - 1);
666     values = values.flatten(0, batch_ndim - 1);
667     dense = dense.flatten(0, batch_ndim - 1);
668   }
669 
670   // At this point there is only one batch dim, which either existed already
671   // or was flattened from multiple batch dims.  Now, reshape the resulting
672   // dense matrix so that this single batch dim is joined with sparse
673   // dims into a single dim, so that the remaining dims are only block
674   // dims eventually, and then dense dims.
675   auto n_batch = values.size(0);
676   int64_t nrows = 0, ncols = 0;
677   auto dense_reshaped_sizes = dense.sizes().vec();
678   if (!block_sparse) {
679     nrows = self.size(batch_ndim);
680     ncols = self.size(batch_ndim + 1);
681     dense_reshaped_sizes.erase(dense_reshaped_sizes.begin(), dense_reshaped_sizes.begin() + 2);
682   } else {
683     std::array<int64_t, 2> blocksize = {values.size(2), values.size(3)};
684     nrows = self.size(batch_ndim) / blocksize[0];
685     ncols = self.size(batch_ndim + 1) / blocksize[1];
686     dense_reshaped_sizes[1] = blocksize[0];
687     dense_reshaped_sizes[2] = blocksize[1];
688   }
689   dense_reshaped_sizes[0] = n_batch * nrows * ncols;
690   dense = dense.reshape(dense_reshaped_sizes);
691 
692   // Calculate batch, row and column indices for non-zeros in the
693   // sparse matrix, and use these to calculate corresponding indices
694   // into the dense matrix reshaped as above.  Then, update dense
695   // matrix by adding sparse matrix values into elements with indices
696   // calculated this way.
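  // As a rough illustration, assuming a single (non-batched) 2x3 CSR matrix
  // with crow = [0, 2, 3], col = [0, 2, 1] and values = [a, b, c]:
  //   batch_indices                       = [0, 0, 0]
  //   compressed_indices_over_all_batches = [0, 2, 3]
  //   COO (row, col) indices              = ([0, 0, 1], [0, 2, 1])
  //   offsets = col + row*ncols           = [0, 2, 4]
  // so index_add_ scatters [a, b, c] into the flattened dense buffer of
  // length 6, i.e. [[a, 0, b], [0, c, 0]] after the final reshape.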
697   auto options = compressed_indices.options();
698   auto nnz_per_batch = values.size(1);
699   auto batch_indices = at::arange(0, n_batch, options).repeat_interleave(nnz_per_batch);
700   auto ncompressed = compressed_rows ? nrows : ncols;
701   auto compressed_indices_over_all_batches =
702     at::cat({compressed_indices.slice(1, 0, ncompressed).flatten()
703             + nnz_per_batch * at::arange(0, n_batch, options).repeat_interleave(ncompressed),
704             n_batch * nnz_per_batch * at::ones({1}, options)});
705   Tensor indices = at::_convert_indices_from_csr_to_coo(
706       compressed_indices_over_all_batches,
707       plain_indices.flatten(),
708       false,
709       !compressed_rows);
710   auto row_indices = indices.select(0, 0);
711   auto col_indices = indices.select(0, 1);
712   if (compressed_rows) {
713     row_indices -= batch_indices * nrows;
714   } else {
715     col_indices -= batch_indices * ncols;
716   }
717   auto offsets = col_indices + row_indices * ncols + batch_indices * nrows * ncols;
718   dense.index_add_(0, offsets, values.flatten(0, 1));
719 
720   // Un-tile the result.  The final reshape uses the original
721   // self.sizes() which will squeeze out the extra batch dim if we put
722   // one in.
723   if (!block_sparse) {
724     return dense.reshape(self.sizes());
725   } else {
726     return dense
727       .unflatten(0, {-1, nrows, ncols})
728         .transpose(2, 3)
729         .reshape(self.sizes());
730   }
731 }
732 
733 // Computes the strides for view_dtype output when the view dtype is
734 // smaller than the original dtype
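// As a rough illustration, viewing a contiguous (4, 6) float32 tensor
// (strides (6, 1), itemsize 4) as uint8 gives size_ratio = 4, so the result
// has strides (24, 1); the last size and the storage offset are scaled by 4
// in view_dtype below.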
735 inline SymDimVector compute_strides_for_view_dtype_downsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
736   const int64_t ndim = old_strides.size();
737 
738   TORCH_CHECK(
739     old_strides[ndim - 1] == 1,
740     "self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
741     " (different element sizes), but got ", old_strides[ndim - 1]);
742 
743   SymDimVector new_strides(ndim);
744   for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
745     new_strides[dim_idx] = old_strides[dim_idx] * size_ratio;
746   }
747   new_strides[ndim - 1] = 1;
748   return new_strides;
749 }
750 
751 // Computes the strides for view_dtype output when the view dtype is
752 // larger than the original dtype
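// As a rough illustration, viewing a contiguous (4, 24) uint8 tensor
// (strides (24, 1)) as float32 gives size_ratio = 4, so the result has
// strides (6, 1), with the last size and the storage offset divided by 4 in
// view_dtype below; the checks here require every non-last stride to be
// divisible by size_ratio.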
753 inline SymDimVector compute_strides_for_view_dtype_upsize(SymIntArrayRef old_strides, int64_t size_ratio, ScalarType old_dtype, ScalarType new_dtype) {
754   const int64_t ndim = old_strides.size();
755   TORCH_CHECK(
756     old_strides[ndim - 1] == 1,
757     "self.stride(-1) must be 1 to view ", old_dtype, " as ", new_dtype,
758     " (different element sizes), but got ", old_strides[ndim - 1]);
759 
760   SymDimVector new_strides(ndim);
761   for (int64_t dim_idx = 0; dim_idx < ndim - 1; dim_idx++) {
762     TORCH_CHECK(
763       (old_strides[dim_idx] % size_ratio) == 0,
764       "self.stride(", dim_idx, ") must be divisible by ", size_ratio,
765       " to view ", old_dtype, " as ", new_dtype, " (different element sizes), ",
766       "but got ", old_strides[dim_idx]);
767 
768     new_strides[dim_idx] = old_strides[dim_idx] / size_ratio;
769   }
770   new_strides[ndim - 1] = 1;
771   return new_strides;
772 }
773 
774 Tensor view_dtype(const Tensor& self, ScalarType dtype) {
775   if (self.scalar_type() == dtype) {
776     return self;
777   }
778   const auto type_meta = c10::scalarTypeToTypeMeta(dtype);
779   TORCH_CHECK(!self.is_conj(),
780     "torch.Tensor.view is not supported for conjugate view tensors when converting to a different dtype.");
781   TORCH_CHECK(!self.is_neg(),
782     "torch.Tensor.view is not supported for tensors with negative bit set when converting to a different dtype.");
783 
784   int64_t self_element_size = self.element_size();
785   int64_t new_element_size = static_cast<int64_t>(type_meta.itemsize());
786 
787   Storage storage = self.storage();
788   auto new_tensor = detail::make_tensor<TensorImpl>(
789       std::move(storage), self.key_set(), type_meta);
790   auto* impl = new_tensor.unsafeGetTensorImpl();
791 
792   if (self_element_size == new_element_size) {
793     impl->set_sizes_and_strides(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
794 
795   } else if (self.dim() == 0) {
796     TORCH_CHECK(false,
797       "self.dim() cannot be 0 to view ", self.scalar_type(), " as ",
798       dtype, " (different element sizes)");
799 
800   } else if (self_element_size > new_element_size) {
801     // Downsizing element size
802 
803     int64_t size_ratio = self_element_size / new_element_size;
804     auto new_strides = compute_strides_for_view_dtype_downsize(
805       self.sym_strides(), size_ratio, self.scalar_type(), dtype);
806 
807     auto old_sizes = self.sym_sizes();
808     SymDimVector new_sizes(self.dim());
809     std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
810     new_sizes[self.dim() - 1] *= size_ratio;
811 
812     auto new_storage_offset = size_ratio * self.sym_storage_offset();
813 
814     impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
815 
816   } else {
817     // Upsizing element size
818 
819     int64_t size_ratio = new_element_size / self_element_size;
820 
821     TORCH_CHECK(
822       (self.sym_size(-1) % size_ratio) == 0,
823       "self.size(-1) must be divisible by ", size_ratio, " to view ",
824       self.scalar_type(), " as ", dtype, " (different element sizes), ",
825       "but got ", self.sym_size(-1));
826 
827     TORCH_CHECK(
828       (self.sym_storage_offset() % size_ratio) == 0,
829       "self.storage_offset() must be divisible by ", size_ratio, " to view ",
830       self.scalar_type(), " as ", dtype, " (different element sizes), but got ",
831       self.sym_storage_offset());
832 
833     auto new_strides = compute_strides_for_view_dtype_upsize(
834       self.sym_strides(), size_ratio, self.scalar_type(), dtype);
835 
836     auto old_sizes = self.sym_sizes();
837     SymDimVector new_sizes(self.dim());
838     std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
839     new_sizes[self.dim() - 1] /= size_ratio;
840 
841     auto new_storage_offset = self.sym_storage_offset() / size_ratio;
842 
843     impl->set_sizes_and_strides(new_sizes, new_strides, new_storage_offset);
844   }
845 
846   return new_tensor;
847 }
848 
849 static Tensor _tile_tensor(const Tensor& self, IntArrayRef blocksize) {
850   // This code turns a matrix into a sequence of blocks
851   //
852   // Given matrix
853   //
854   //  1  2  3  4
855   //  5  6  7  8
856   //  9 10 11 12
857   // 14 15 16 17
858   //
859   // _tile_tensor(matrix, {2, 2}) will yield the following 2 by 2 blocks
860   //
861   //  1  2 |  3  4 |  9 10 | 11 12
862   //  5  6 |  7  8 | 14 15 | 16 17
863   //
864   //  via a 4D Tensor of shape (2, 2, 2, 2)
865   //
866   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0);
867   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0);
868   auto block_size_0 = self.size(0) / blocksize[0];
869   auto block_size_1 = self.size(1) / blocksize[1];
870 
871   auto new_shape = DimVector({block_size_0, blocksize[0], block_size_1, blocksize[1]});
872   new_shape.append(DimVector(self.sizes().slice(2, self.dim() - 2)));
873   return self.reshape(new_shape)
874       .transpose(1, 2)
875       .contiguous();
876 }
877 
878 static Tensor _batch_tile_tensor(const Tensor& self, IntArrayRef blocksize, const int64_t dense_dim) {
879   if (self.dim() == 2 + dense_dim) {
880     return _tile_tensor(self, blocksize);
881   }
882   auto n_batch_dim = self.dim() - 2 - dense_dim;
883   // Same as _tile_tensor, but applied per matrix slice of self, e.g. when self is 3D.
884   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[0] > 0);
885   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(blocksize[1] > 0);
886   auto block_size_0 = self.size(n_batch_dim) / blocksize[0];
887   auto block_size_1 = self.size(n_batch_dim + 1) / blocksize[1];
888   auto tiled_sizes = DimVector(self.sizes().slice(0, n_batch_dim));
889   tiled_sizes.push_back(block_size_0);
890   tiled_sizes.push_back(blocksize[0]);
891   tiled_sizes.push_back(block_size_1);
892   tiled_sizes.push_back(blocksize[1]);
893   tiled_sizes.append(DimVector(self.sizes().slice(n_batch_dim + 2, dense_dim)));
894   return self.reshape(tiled_sizes).transpose(n_batch_dim + 1, n_batch_dim + 2).contiguous();
895 }
896 
897 static Tensor _mask_to_indices(const Tensor& mask) {
898   // This function returns a vector of the indices at which the given
899   // boolean mask is True. at::nonzero can achieve the same, but
900   // we have yet to compare their performance.
901   TORCH_CHECK(mask.dim() == 1, "Currently _mask_to_indices only supports 1-d masks.");
902   TORCH_CHECK(mask.dtype() == at::kBool, "Expected mask to be of dtype bool.");
903   return at::native::arange(
904       mask.numel(), at::kLong, kStrided, mask.device())
905       .masked_select(mask);
906 }
907 
908 static std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
909     Tensor not_zero_mask,
910     ScalarType index_dtype,
911     Device index_device) {
912   auto col_indices =
913       at::native::arange(not_zero_mask.size(-1), index_dtype, kStrided, index_device)
914           .view({1, not_zero_mask.size(-1)})
915           .expand_as(not_zero_mask)
916           .masked_select(not_zero_mask);
917   auto row_indices =
918       at::native::arange(
919           not_zero_mask.size(-2), index_dtype, kStrided, index_device)
920           .view({not_zero_mask.size(-2), 1})
921           .expand_as(not_zero_mask)
922           .masked_select(not_zero_mask);
923   return std::pair<Tensor, Tensor>(col_indices, row_indices);
924 }
925 
926 // Sparse layout conversions Start
927 
928 static inline
929 void _to_sparse_check_arguments(const std::string& funcname, const Tensor& self, const int64_t sparse_dim) {
930   auto layout_from = self.layout();
931 
932   auto layout_from_valid = layout_from == kStrided || layout_from == kSparse || at::sparse_csr::is_sparse_compressed(layout_from);
933   if (!layout_from_valid) {
934     AT_ERROR(funcname, ": unexpected source layout ", layout_from);
935   }
936 
937   if (layout_from == kStrided) {
938     if (sparse_dim == 0 && self.dim() > 0) {
939       AT_ERROR(funcname, ": sparse_dim argument must be >0 when self.dim()>0");
940     }
941     if (sparse_dim < 0 || sparse_dim > self.dim()) {
942       AT_ERROR(funcname, ": sparse_dim argument must be in [0,", self.dim(), "] range, but ", sparse_dim, " is given");
943     }
944   } else if (layout_from == kSparse) {
945     if (sparse_dim != self.sparse_dim()) {
946       AT_ERROR(funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=self.sparse_dim() is not supported");
947     }
948   } else if (at::sparse_csr::is_sparse_compressed(layout_from)) {
949     if (sparse_dim != 2) {
950       AT_ERROR(funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=2 is not supported");
951     }
952   }
953 }
954 
955 static inline
956 void _to_sparse_check_arguments(const std::string& funcname, const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
957   auto layout_from = self.layout();
958   auto layout_to = layout.value_or(kSparse);
959 
960   auto layout_from_valid = layout_from == kStrided || layout_from == kSparse || at::sparse_csr::is_sparse_compressed(layout_from);
961   if (!layout_from_valid) {
962     AT_ERROR(funcname, ": unexpected source layout ", layout_from);
963   }
964   auto layout_to_valid = layout_to == kStrided || layout_to == kSparse || at::sparse_csr::is_sparse_compressed(layout_to);
965   if (!layout_to_valid) {
966     AT_ERROR(funcname, ": unexpected target layout ", layout_to);
967   }
968 
969   if (layout_from == kSparse && layout_to != kSparse) {
970     if (self.sparse_dim() != 2) {
971       AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " for input tensors with sparse_dim()!=2 is not supported");
972     }
973   }
974 
975   if ((layout_from == kSparseCsr || layout_from == kSparseCsc) &&
976       (layout_to == kSparseBsr || layout_to == kSparseBsc)) {
977     if (sparse_csr::numBatchDimensions(self) > 0) {
978       AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " for batched inputs is not supported");
979     }
980   }
981 
982   if (blocksize.has_value()) {
983     if (blocksize.value().size() != 2) {
984       AT_ERROR(funcname, ": blocksize needs to be a tuple of size 2, but got ", blocksize.value().size());
985     }
986     auto blocksize_to = *blocksize;
987     if (blocksize_to[0] <= 0 || blocksize_to[1] <= 0) {
988       AT_ERROR(funcname, ": blocksize needs to be positive, but got ", blocksize_to);
989     }
990 
991     if (layout_to == kSparseBsr || layout_to == kSparseBsc) {
992       if (layout_from == kSparseBsr || layout_from == kSparseBsc) {
993         auto blocksize_from = at::sparse_csr::getBlockSize(self);
994         if (!(blocksize_to == blocksize_from)) {
995           AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize changed from ", blocksize_from, " to ", blocksize_to, " is not supported");
996         }
997       } else {
998         auto dense_dim = (layout_from == kStrided) ? dense_dim_opt.value_or(0) : self.dense_dim();
999         auto sparse_row_dim = -(dense_dim + 2);
1000         auto sparse_col_dim = -(dense_dim + 1);
1001         if ((self.size(sparse_row_dim) % blocksize_to[0] != 0) ||
1002             (self.size(sparse_col_dim) % blocksize_to[1] != 0)) {
1003             AT_ERROR(funcname, ": tensor sparse size (", self.size(sparse_row_dim), ",", self.size(sparse_col_dim), ") must be divisible by given blocksize (", blocksize_to[0], ",", blocksize_to[1], ")");
1004         }
1005       }
1006     } else {
1007       AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize argument given is not supported");
1008     }
1009   } else {
1010     if ((layout_to == kSparseBsr || layout_to == kSparseBsc) &&
1011         !(layout_from == kSparseBsr || layout_from == kSparseBsc)) {
1012       AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " without blocksize argument given is not supported");
1013     }
1014   }
1015 
1016   if (dense_dim_opt.has_value()) {
1017     if (layout_from != kStrided) {
1018       AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with dense_dim argument given is not supported");
1019     }
1020 
1021     auto dense_dim = *dense_dim_opt;
1022     if (layout_to == kSparse) {
1023       if (dense_dim == self.dim() && self.dim() > 0) {
1024         AT_ERROR(funcname, ": dense_dim argument must be !=self.dim() when self.dim()>0");
1025       }
1026       if (dense_dim < 0 || dense_dim > self.dim()) {
1027         AT_ERROR(funcname, ": dense_dim argument must be in [0,", self.dim(), "] range, but ", dense_dim, " is given");
1028       }
1029     } else {
1030       if (dense_dim < 0 || dense_dim > self.dim() - 2) {
1031         AT_ERROR(funcname, ": dense_dim argument must be in [0,", self.dim() - 2, "] range, but ", dense_dim, " is given");
1032       }
1033     }
1034   }
1035 }
1036 
1037 template<Layout target_layout>
1038 static Tensor dense_to_sparse_compressed(const Tensor& self, const Tensor& self_mask, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1039   static_assert(target_layout == Layout::SparseCsr || target_layout == Layout::SparseCsc
1040                 || target_layout == Layout::SparseBsr || target_layout == Layout::SparseBsc,
1041                 "invalid layout template parameter for dense_to_sparse_compressed");
1042   constexpr auto compressed_rows_layout = target_layout == Layout::SparseCsr || target_layout == Layout::SparseBsr;
1043   constexpr auto blocked_layout = target_layout == Layout::SparseBsr || target_layout == Layout::SparseBsc;
1044 
1045   int64_t dense_dim = dense_dim_opt.value_or(0);
1046 
1047   // Reshape values so that the block dims are explicitly added, and
1048   // calculate a mask tensor that has only batch and sparse dims, and
1049   // value true whenever sparse matrix has a non-zero element over
1050   // corresponding block and dense dims, and false otherwise.
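  // As a rough illustration, for a BSR target with self of shape (6, 8, 3),
  // blocksize (2, 4) and dense_dim = 1: values and not_zero_mask become
  // (3, 2, 2, 4, 3) after tiling, and the reduction below collapses the block
  // and dense dims so that not_zero_mask ends up as a (3, 2) boolean mask
  // over blocks.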
1051   auto n_batch_dim = self.dim() - 2 - dense_dim;
1052   auto is_batched = n_batch_dim > 0;
1053   auto values = blocked_layout ? _batch_tile_tensor(self, blocksize, dense_dim) :  self;
1054   auto not_zero_mask = blocked_layout ? _batch_tile_tensor(self_mask, blocksize, dense_dim) : self_mask;
1055   if (blocked_layout || dense_dim > 0) {
1056     std::vector<int64_t> reduce_dim((blocked_layout ? 2 : 0) + dense_dim);
1057     std::iota(reduce_dim.begin(), reduce_dim.end(), n_batch_dim + 2);
1058     not_zero_mask = not_zero_mask.sum(reduce_dim) != 0;
1059   }
1060 
1061   if (is_batched) {
1062     // Prepare for the conversion, in particular join the batch dims
1063     // and the compressed dim into a single dim.
1064     dense_to_sparse_compressed_prepare_check_mask_values_batched(
1065         target_layout, values, not_zero_mask, n_batch_dim);
1066   }
1067 
1068   // Calculate sparse matrix row and col indices and then, depending
1069   // on the target layout, corresponding compressed and sparse
1070   // indices.  Use the mask tensor calculated above to generate the sparse
1071   // matrix values tensor.
1072   Tensor row_indices;
1073   Tensor col_indices;
1074   Tensor compressed_indices;
1075   if (compressed_rows_layout) {
1076     std::tie(col_indices, row_indices) = _not_zero_mask_to_col_row_indices(
1077         not_zero_mask, at::kLong, not_zero_mask.device());
1078     compressed_indices = at::_convert_indices_from_coo_to_csr(
1079         row_indices, not_zero_mask.size(0), false /*out_int32*/);
1080     {
1081       auto mask_indices = _mask_to_indices(not_zero_mask.flatten());
1082       values = values.flatten(0, 1).index_select(0, mask_indices);
1083     }
1084   } else {
1085     std::tie(row_indices, col_indices) = _not_zero_mask_to_col_row_indices(
1086        not_zero_mask.transpose(1, 0), at::kLong, not_zero_mask.device());
1087     compressed_indices = at::_convert_indices_from_coo_to_csr(
1088         col_indices, not_zero_mask.size(-1), false /*out_int32*/);
1089     {
1090       auto mask_indices = _mask_to_indices(not_zero_mask.transpose(0, 1).flatten());
1091       values = values.transpose(0, 1).flatten(0, 1).index_select(0, mask_indices);
1092     }
1093   }
1094   Tensor& plain_indices = compressed_rows_layout ? col_indices : row_indices;
1095 
1096   if (is_batched) {
1097    // Restore the batch dims and compressed dim.
1098     reshape_2d_sparse_compressed_members_to_nd_batched(
1099         self.sizes(), n_batch_dim, compressed_indices, plain_indices, values);
1100   }
1101 
1102   // Create compressed sparse matrix with the target layout.
1103   return at::_sparse_compressed_tensor_unsafe(
1104         compressed_indices,
1105         plain_indices,
1106         values,
1107         self.sizes(),
1108         self.options().layout(target_layout));
1109 }
1110 
1111 Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1112   auto layout_to = layout.value_or(kSparse);
1113   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "dense_to_sparse_with_mask: unexpected same input and output layout");
1114   TORCH_INTERNAL_ASSERT(self.layout() == mask.layout(),
1115                         "dense_to_sparse_with_mask: expected mask layout ", self.layout(), ", got ", mask.layout());
1116   TORCH_INTERNAL_ASSERT(self.sizes() == mask.sizes(),
1117                         "dense_to_sparse_with_mask: expected mask size ", self.sizes(), ", got ", mask.sizes());
1118   _to_sparse_check_arguments("dense_to_sparse_with_mask", self, layout, blocksize, dense_dim_opt);
1119 
1120   switch (layout_to) {
1121   case kSparse:
1122     return self.sparse_mask(mask.to_sparse(self.dim() - dense_dim_opt.value_or(0)));
1123   case kSparseCsr:
1124     return dense_to_sparse_compressed<Layout::SparseCsr>(self, mask, {}, dense_dim_opt);
1125   case kSparseCsc:
1126     return dense_to_sparse_compressed<Layout::SparseCsc>(self, mask, {}, dense_dim_opt);
1127   case kSparseBsr:
1128     return dense_to_sparse_compressed<Layout::SparseBsr>(self, mask, *blocksize, dense_dim_opt);
1129   case kSparseBsc:
1130     return dense_to_sparse_compressed<Layout::SparseBsc>(self, mask, *blocksize, dense_dim_opt);
1131   default:
1132     break;
1133   }
1134 
1135   AT_ERROR("dense_to_sparse_with_mask: ", self.layout(), " to ", layout_to, " conversion not supported");
1136   return Tensor{};
1137 }
1138 
1139 Tensor dense_to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1140   auto layout_to = kSparseCsr;
1141   _to_sparse_check_arguments("dense_to_sparse_csr", self, layout_to, {}, dense_dim_opt);
1142 
1143   return dense_to_sparse_compressed<Layout::SparseCsr>(self, self != 0, {}, dense_dim_opt);
1144 }
1145 
1146 Tensor dense_to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1147   auto layout_to = kSparseCsc;
1148   _to_sparse_check_arguments("dense_to_sparse_csc", self, layout_to, {}, dense_dim_opt);
1149 
1150   return dense_to_sparse_compressed<Layout::SparseCsc>(self, self != 0, {}, dense_dim_opt);
1151 }
1152 
1153 Tensor dense_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1154   auto layout_to = kSparseBsr;
1155   _to_sparse_check_arguments("dense_to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);
1156 
1157   return dense_to_sparse_compressed<Layout::SparseBsr>(self, self != 0, blocksize, dense_dim_opt);
1158 }
1159 
1160 Tensor dense_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1161   auto layout_to = kSparseBsc;
1162   _to_sparse_check_arguments("dense_to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);
1163 
1164   return dense_to_sparse_compressed<Layout::SparseBsc>(self, self != 0, blocksize, dense_dim_opt);
1165 }
1166 
1167 Tensor dense_to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1168   auto layout_to = layout.value_or(kSparse);
1169   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "dense_to_sparse: unexpected same input and output layout");
1170   _to_sparse_check_arguments("dense_to_sparse", self, layout, blocksize, dense_dim_opt);
1171 
1172   switch (layout_to) {
1173   case kSparse:
1174     return self.to_sparse(self.dim() - dense_dim_opt.value_or(0));
1175   case kSparseCsr:
1176     return self.to_sparse_csr(dense_dim_opt);
1177   case kSparseCsc:
1178     return self.to_sparse_csc(dense_dim_opt);
1179   case kSparseBsr:
1180     return self.to_sparse_bsr(*blocksize, dense_dim_opt);
1181   case kSparseBsc:
1182     return self.to_sparse_bsc(*blocksize, dense_dim_opt);
1183   default:
1184     break;
1185   }
1186 
1187   AT_ERROR("dense_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported");
1188   return Tensor{};
1189 }
1190 
1191 Tensor dense_to_sparse(const Tensor& self, int64_t sparse_dim) {
1192   _to_sparse_check_arguments("dense_to_sparse", self, sparse_dim);
1193 
1194   int64_t dims = self.dim();
1195   at::TensorOptions sparse_options = self.options().layout(kSparse);
1196   std::vector<int64_t> sizes = self.sizes().vec();
1197   Tensor nz = self.nonzero().transpose(0, 1);
1198   if (nz.size(1) == 0) {
1199     auto sparse = new_with_dims_sparse(
1200         sparse_dim,
1201         dims - sparse_dim,
1202         sizes,
1203         optTypeMetaToScalarType(sparse_options.dtype_opt()),
1204         sparse_options.layout_opt(),
1205         sparse_options.device_opt(),
1206         sparse_options.pinned_memory_opt());
1207     return sparse._coalesced_(true);
1208   }
1209   Tensor indices;
1210   if (sparse_dim == dims) {
1211     indices = nz.clone();
1212   } else {
1213     Tensor i = nz.narrow(0, 0, sparse_dim);
1214     std::tie(indices, std::ignore, std::ignore) = unique_dim(i, 1);
1215     indices = indices.contiguous(); // many sparse CUDA kernels require
1216                                     // contiguity, see issue #12633
1217   }
1218 
1219   Tensor values;
1220   if (self.dim() > 0) {
1221     auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0));
1222     values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
1223   } else {
1224     AT_ASSERT(nz.sizes().equals({0, 1}));
1225     // In this case, indices is a clone of nz, which is a tensor of shape (0,
1226     // 1). Given the sparse tensor invariants, values should have shape (1,).
1227     values = self.unsqueeze(0).clone(at::MemoryFormat::Preserve);
1228   }
1229 
1230   Tensor sparse = at::sparse_coo_tensor(indices, values, sizes, sparse_options);
1231   return sparse._coalesced_(true);
1232 }
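
// Worked example (illustrative only) for the dense -> COO conversion above with
// sparse_dim == 2 and a 2x3 input:
/*
  self = [[0, 7, 0],
          [5, 0, 6]]
  self.nonzero()                  -> [[0, 1], [1, 0], [1, 2]]   (shape (nnz, ndim))
  nz = nonzero().transpose(0, 1)  -> [[0, 1, 1],
                                      [1, 0, 2]]                (shape (ndim, nnz))
  values (via self.index(...))    -> [7, 5, 6]
  result: sparse_coo_tensor(nz, values, {2, 3}), already coalesced because
  nonzero() reports indices in row-major (lexicographic) order.
*/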
1233 
1234 static Tensor sparse_compressed_to_flipped(
1235     const Tensor& self,
1236     std::optional<IntArrayRef> blocksize,
1237     const std::string& name) {
1238   const auto layout = self.layout();
1239   // NOTE: errors on non-compressed sparse layouts.
1240   const auto flipped_layout = at::sparse_csr::flip_compressed_layout(layout);
1241 
1242   // Suppose compressed_indices represent rows of an input in either
1243   // CSR or BSR sparse compressed format.
1244   // In order to convert a batched CSR/BSR index into a batched CSC/BSC index
1245   // we perform the following steps:
1246   // 1. Convert a sparse compressed index representing batches of matrices of
1247   //    shape (b, r, c) to a sparse compressed index that represents a single
1248   //    matrix of shape (b * r, c).
1249   // 2. Turn the compressed indices of the matrix of shape (b * r, c) into
1250   //    COO indices.
1251   // 3. Map these COO indices into the COO indices of a matrix of shape (r, b * c)
1252   //    such that if A is a matrix of shape (b * r, c) and B is a matrix of
1253   //    shape (r, b * c) related by
1254   //    A[(k * r):(k * r + r), :] = B[:, (k * c):(k * c + c)] for all k in arange(b),
1255   //    then the entry A[i, j] maps to B[i', j'] with A[i, j] = B[i', j'].
1256   //    This is equivalent to finding indices that match values of matrices
1257   //    tiled vertically to values of the same matrices tiled horizontally.
1258   // 4. Convert the COO indices to the CSC/BSC indices and form the output.
1259   //
1260   // NOTE: the reason behind vertical/horizontal tiling is to be able to transform
1261   //       indices over all matrices in the batch in a single kernel call, since
1262   //       all the existing coo <-> compressed indices conversion methods assume
1263   //       a single matrix.
1264   //
1265   // CSC/BSC inputs are handled in a similar fashion with a "transposed" argument.
1266   // See the comments below for detailed explanations on how exactly each step
1267   // is performed.
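  //
  // A concrete (illustrative) instance of the step-3 map with b = 2, r = 2, c = 3:
  // A has shape (4, 3) and B has shape (2, 6). The entry A[3, 1] belongs to
  // batch k = 3 / 2 = 1, so i' = 3 % 2 = 1 and j' = 1 + 1 * 3 = 4, i.e. it maps
  // to B[1, 4], consistent with A[2:4, :] = B[:, 3:6].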
1268 
1269   Tensor compressed_indices, plain_indices;
1270   std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(self);
1271   auto values = self.values();
1272   const auto nnz = plain_indices.size(-1);
1273 
1274   const auto n_batches = compressed_indices.dim() - 1;
1275   auto n_batches_nonzero = n_batches;
1276   // Insert fake batch dim for simplicity
1277   if (!n_batches) {
1278     n_batches_nonzero = 1;
1279     compressed_indices.unsqueeze_(0);
1280     plain_indices.unsqueeze_(0);
1281     values.unsqueeze_(0);
1282   }
1283 
1284   // NOTE: sparse_dim holds the true sparse dims only for CSR/CSC
1285   // inputs; for BSR/BSC it holds <true sparse dims> / <blocksize>.
1286   // In other words, sparse_dim stores the ranges of valid indices
1287   // in the row/col dims.
1288   const auto sparse_dim = [&]() -> at::DimVector {
1289     auto sparse_dim = at::DimVector(self.sizes().slice(n_batches, 2));
1290     if (layout == at::kSparseBsr || layout == at::kSparseBsc) {
1291       auto blocksize = at::sparse_csr::getBlockSize(self);
1292       sparse_dim[0] /= blocksize[0];
1293       sparse_dim[1] /= blocksize[1];
1294     }
1295     return sparse_dim;
1296   }();
1297 
1298   // batch_sizes_nonempty stores at least one, potentially fake, batch dimension.
1299   // rebatch_sizes_nonempty is equivalent to batch_sizes_nonempty.push_back(-1),
1300   // and is used to unflatten batch dimensions from a dimension of size
1301   // (batch_numel * dim_size,) for some dim_size.
1302   const auto batch_sizes_nonempty = at::DimVector(plain_indices.sizes().slice(0, n_batches_nonzero));
1303   auto rebatch_sizes_nonempty = at::DimVector(batch_sizes_nonempty);
1304   rebatch_sizes_nonempty.push_back(-1);
1305   const auto batch_numel_nonzero = std::accumulate(
1306       batch_sizes_nonempty.begin(),
1307       batch_sizes_nonempty.begin() + n_batches_nonzero,
1308       1,
1309       std::multiplies<int64_t>());
1310 
1311   // Equivalent to (arange(batch_numel_nonzero).mul_(nnz)).reshape(batch_sizes_nonempty).
1312   // We just compute it differently to use `add` kernel in place of `mul` for better
1313   // performance.
1314   const auto batch_nnz_offset = [&]() -> Tensor {
1315     const auto wrapped_nnz = at::tensor({nnz}, compressed_indices.options());
1316     auto offset = wrapped_nnz
1317       .expand({batch_numel_nonzero})
1318       .cumsum(-1).sub_(wrapped_nnz)
1319       .reshape(batch_sizes_nonempty);
1320     return offset;
1321   }();
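  // Illustrative numeric check (not executed here): with nnz = 5 and
  // batch_numel_nonzero = 3, the expanded tensor is [5, 5, 5], its cumsum is
  // [5, 10, 15], and subtracting wrapped_nnz gives [0, 5, 10], i.e.
  // arange(3) * 5 reshaped to batch_sizes_nonempty.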
1322 
1323   // Step 1 for CSR/BSR inputs:
1324   // Convert a sparse compressed index representing batches of matrices of
1325   // shape (b, r, c) to a sparse compressed index that represents a single
1326   // matrix of shape (b * r, c).
1327   // The algorithm is identical for CSC/BSC inputs, with the batch dimensions
1328   // flattened in the "transposed" dimension.
1329   const auto compressed_indices_2d = [&]() -> Tensor {
1330     // Extract offsets only relevant for the first :-1 elements in a row/col.
1331     const auto compressed_offsets = compressed_indices.slice(-1, 0, -1);
1332     // batch_offsets offsets each individual matrix row/col offsets by the total
1333     // sum of nnz's of all the matrices with the smaller batch index.
1334     const auto batch_offsets = batch_nnz_offset
1335       .unsqueeze(-1).expand_as(compressed_offsets);
1336     // compressed_offsets + batch_offsets creates an offset vector for a 2d matrix
1337     // that is stored in a compressed sparse format.
1338     const auto compressed_offsets_2d = compressed_offsets.add(batch_offsets).reshape({-1});
1339     const auto offsets_len = compressed_offsets_2d.numel();
1340     auto res = at::empty({offsets_len + 1}, compressed_indices.options());
1341     res.slice(-1, 0, -1).copy_(compressed_offsets_2d);
1342     // By appending nnz * batch_numel_nonzero to (compressed_offsets + batch_offsets)
1343     // a compressed index of a 2d matrix is formed.
1344     res.slice(-1, -1).fill_(nnz * batch_numel_nonzero);
1345     return res;
1346   }();
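  // Illustrative numeric check (not executed here) for two CSR batches with
  // r = 2 rows and nnz = 3 each:
  //   compressed_indices = [[0, 1, 3], [0, 2, 3]]
  //   compressed_offsets = [[0, 1], [0, 2]], batch_offsets = [[0, 0], [3, 3]]
  //   their flattened sum is [0, 1, 3, 5]; appending nnz * batch_numel_nonzero = 6
  //   gives [0, 1, 3, 5, 6], the crow index of the stacked (b * r, c) matrix.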
1347   // More involved for compressed indices, but pretty easy for plain_indices and values:
1348   // just squash batch dimensions.
1349   const auto plain_indices_2d = plain_indices.flatten(0, n_batches_nonzero);
1350   // NOTE: values are not 2d! They just represent values of a sparse compressed 2d matrix.
1351   const auto values_2d = values.flatten(0, n_batches_nonzero);
1352 
1353   const auto is_out_int32 = compressed_indices.scalar_type() == ScalarType::Int;
1354 
1355   // Step 2 & 3:
1356   //
1357   // Turn the compressed indices of the matrix of shape (b * r, c) into COO indices.
1358   //
1359   // Map these COO indices into the COO indices of a matrix of shape (r, b * c)
1360   // such that if A is a matrix of shape (b * r, c) and B is a matrix of
1361   // shape (r, b * c) related by
1362   // A[(k * r):(k * r + r), :] = B[:, (k * c):(k * c + c)] for all k in arange(b),
1363   // then the entry A[i, j] maps to B[i', j'] with A[i, j] = B[i', j'].
1364   // This is equivalent to finding indices that match values of matrices
1365   // tiled vertically to values of the same matrices tiled horizontally.
1366 
1367   // coo <-> sparse index conversions assume CSR/BSR inputs.
1368   // To CSC/BSC inputs these indices will appear "transposed".
1369   const auto is_transposed_indices = layout == at::kSparseCsc || layout == at::kSparseBsc;
1370   const auto coo_indices_2d_transposed = [&]() -> Tensor {
1371     auto coo_indices_2d = _convert_indices_from_csr_to_coo(
1372         compressed_indices_2d,
1373         plain_indices_2d,
1374         is_out_int32,
1375         /*transpose=*/true); // Flip rows/cols for convenience.
1376     // Convert COO indices of (b * r, c) to (r, b * c).
1377     // It is a map (i, j) -> {
1378     //    b = i // r
1379     //    i' = i % r
1380     //    j' = j + b * c
1381     //    return (i', j')
1382     // }
1383     // NOTE: we used transposed=true above!
1384     auto i = coo_indices_2d.select(0, 1);
1385     auto j = coo_indices_2d.select(0, 0);
1386     auto b = i.div(is_transposed_indices ? sparse_dim[1] : sparse_dim[0], "trunc");
1387     // Modify i, j in-place.
1388     i.fmod_(is_transposed_indices ? sparse_dim[1] : sparse_dim[0]);
1389     j.add_(b * (is_transposed_indices ? sparse_dim[0] : sparse_dim[1]));
1390     return coo_indices_2d;
1391   }();
1392 
1393   // Step 4:
1394   // Convert the COO indices to the CSC/BSC indices and form the output.
1395   // We need to sort COO indices along the "transposed" dim to satisfy the
1396   // invariant of sorted plain indices.
1397   // Hash coo indices by converting 2d indices to linear offsets with
1398   // more "weight" (aka stride) placed on the "transposed" dimension.
1399   const auto coo_indices_2d_transposed_hashed = at::sparse::flatten_indices(
1400       coo_indices_2d_transposed,
1401       is_transposed_indices ? at::DimVector({sparse_dim[0], sparse_dim[1] * batch_numel_nonzero})
1402                             : at::DimVector({sparse_dim[1], sparse_dim[0] * batch_numel_nonzero}));
1403   const auto hash_argsort = std::get<1>(coo_indices_2d_transposed_hashed.sort());
1404   const auto coo_indices_2d_transposed_sorted = coo_indices_2d_transposed.index_select(1, hash_argsort);
1405 
1406   const auto new_compressed_indices_coo_2d = coo_indices_2d_transposed_sorted.select(0, 0);
1407   const auto new_plain_indices_2d = coo_indices_2d_transposed_sorted.select(0, 1);
1408   const auto new_values_2d = values_2d.index_select(0, hash_argsort);
1409 
1410   auto new_compressed_indices = compressed_to_batched_compressed_indices(
1411       _convert_indices_from_coo_to_csr(
1412         new_compressed_indices_coo_2d,
1413         is_transposed_indices
1414           ? batch_numel_nonzero * sparse_dim[0]
1415           : batch_numel_nonzero * sparse_dim[1],
1416         is_out_int32),
1417       batch_numel_nonzero,
1418       is_out_int32)
1419     .unflatten(0, batch_sizes_nonempty);
1420   auto new_plain_indices = new_plain_indices_2d.unflatten(0, rebatch_sizes_nonempty);
1421   auto new_values = new_values_2d.unflatten(0, rebatch_sizes_nonempty);
1422   // Kill fake batch dim if it was inserted.
1423   if (!n_batches) {
1424     new_compressed_indices.squeeze_(0);
1425     new_plain_indices.squeeze_(0);
1426     new_values.squeeze_(0);
1427   }
1428 
1429   return _sparse_compressed_tensor_unsafe(
1430       new_compressed_indices,
1431       new_plain_indices,
1432       new_values,
1433       self.sizes(),
1434       self.options().layout(flipped_layout));
1435 }
1436 
1437 Tensor sparse_compressed_to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1438   auto layout_to = kSparseCsr;
1439   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_csr: unexpected same input and output layout");
1440   _to_sparse_check_arguments("sparse_compressed_to_sparse_csr", self, layout_to, {}, dense_dim_opt);
1441 
1442   if (self.layout() == kSparseCsc) {
1443     return sparse_compressed_to_flipped(self, std::nullopt, "to_sparse_csr");
1444   }
1445 
1446   AT_ERROR("sparse_compressed_to_sparse_csr: expected SparseCsr or SparseCsc layout but got ", self.layout());
1447   return Tensor{};
1448 }
1449 
1450 Tensor sparse_compressed_to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1451   auto layout_to = kSparseCsc;
1452   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_csc: unexpected same input and output layout");
1453   _to_sparse_check_arguments("sparse_compressed_to_sparse_csc", self, layout_to, {}, dense_dim_opt);
1454 
1455   if (self.layout() == kSparseCsr) {
1456     return sparse_compressed_to_flipped(self, std::nullopt, "to_sparse_csc");
1457   }
1458 
1459   AT_ERROR("sparse_compressed_to_sparse_csc: expected SparseCsr or SparseCsc layout but got ", self.layout());
1460   return Tensor{};
1461 }
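
// Illustrative usage (not part of this file's build): flipping between the two
// non-block compressed layouts goes through sparse_compressed_to_flipped above.
/*
  at::Tensor csr = at::rand({4, 6}).to_sparse_csr();
  at::Tensor csc = csr.to_sparse_csc();   // CSR -> CSC
  at::Tensor rt  = csc.to_sparse_csr();   // CSC -> CSR, same mechanism
*/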
1462 
1463 Tensor coo_to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1464   auto layout_to = kSparseCsr;
1465   _to_sparse_check_arguments("coo_to_sparse_csr", self, layout_to, {}, dense_dim_opt);
1466 
1467   auto coalesced_self = self.coalesce();
1468   auto row_indices = coalesced_self.indices()[0];
1469   bool out_int32 = (row_indices.scalar_type() == at::kInt);
1470   auto crow_indices = at::_convert_indices_from_coo_to_csr(
1471       row_indices, self.size(0), out_int32);
1472   return at::native::_sparse_csr_tensor_unsafe(
1473       crow_indices,
1474       coalesced_self.indices()[1].contiguous(),
1475       coalesced_self.values(),
1476       coalesced_self.sizes(),
1477       coalesced_self.scalar_type(),
1478       c10::kSparseCsr,
1479       coalesced_self.device());
1480 }
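
// Worked example (illustrative only): for a coalesced COO tensor with 4 rows and
// row indices [0, 0, 1, 3], _convert_indices_from_coo_to_csr with size == 4
// produces crow_indices = [0, 2, 3, 3, 4] (row 2 is empty, hence the repeated 3).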
1481 
1482 Tensor coo_to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1483   auto layout_to = kSparseCsc;
1484   _to_sparse_check_arguments("coo_to_sparse_csc", self, layout_to, {}, dense_dim_opt);
1485 
1486   auto transposed_csr = self.transpose(0, 1).to_sparse_csr(dense_dim_opt);
1487   return at::native::_sparse_csc_tensor_unsafe(
1488       transposed_csr.crow_indices(),
1489       transposed_csr.col_indices(),
1490       transposed_csr.values(),
1491       self.sizes(),
1492       transposed_csr.scalar_type(),
1493       c10::kSparseCsc,
1494       transposed_csr.device());
1495 }
1496 
1497 Tensor coo_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1498   auto layout_to = kSparseBsr;
1499   _to_sparse_check_arguments("coo_to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);
1500 
1501   return self.to_sparse_csr(dense_dim_opt).to_sparse_bsr(blocksize);
1502 }
1503 
1504 Tensor coo_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1505   auto layout_to = kSparseBsc;
1506   _to_sparse_check_arguments("coo_to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);
1507 
1508   return self.to_sparse_csc(dense_dim_opt).to_sparse_bsc(blocksize);
1509 }
1510 
1511 namespace {
1512 template <typename input_t, typename output_t>
1513 void convert_indices_from_coo_to_csr_cpu(
1514     const Tensor& result,
1515     const Tensor& input,
1516     const int64_t size) {
1517   int64_t numel = input.numel();
1518   const input_t* data_in = input.const_data_ptr<input_t>();
1519   output_t* data_out = result.data_ptr<output_t>();
1520 
1521   if (numel == 0) {
1522     result.zero_();
1523     return;
1524   }
1525 
1526   for (int64_t i = 0; i <= data_in[0]; i++)
1527     data_out[i] = static_cast<output_t>(0);
1528 
1529   at::parallel_for(
1530       0, numel - 1, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
1531         input_t curr_value = data_in[start], next_value;
1532         for (const auto i : c10::irange(start, end)) {
1533           next_value = data_in[i + 1];
1534           for (; curr_value < next_value; curr_value++)
1535             data_out[curr_value + 1] = static_cast<output_t>(i + 1);
1536         }
1537       });
1538   for (int64_t i = data_in[numel - 1] + 1; i < size + 1; i++) {
1539     data_out[i] = static_cast<output_t>(numel);
1540   }
1541 }
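
// Illustrative trace (not executed) of the kernel above with data_in = [0, 0, 1, 3]
// and size = 4:
//   - the head loop zeroes data_out up to data_in[0]:     data_out[0] = 0
//   - the parallel loop fills the positions between
//     consecutive input values with the running count:    data_out[1] = 2, data_out[2] = 3, data_out[3] = 3
//   - the tail loop writes numel for the remaining rows:  data_out[4] = 4
// giving data_out = [0, 2, 3, 3, 4].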
1542 
1543 template <typename input_t, typename output_t>
1544 void convert_indices_from_csr_to_coo_cpu(
1545     const Tensor& indices,
1546     const Tensor& crow_indices,
1547     const Tensor& col_indices,
1548     const bool transpose = false) {
1549   int64_t nrows = crow_indices.size(-1) - 1;
1550   int64_t nnz = col_indices.size(-1);
1551   if (nrows == 0 || nnz == 0) {
1552     indices.zero_();  // is this needed as indices has a zero-valued
1553                       // dimension when nrows or nnz is 0?
1554     return;
1555   }
1556   auto crow_indices_ = crow_indices.expect_contiguous();
1557   int64_t total_nnz = col_indices.numel();
1558   int64_t batch_ndim = crow_indices.dim() - 1;
1559   if (batch_ndim > 0) {
1560     auto batch_indices = indices.narrow(0, 0, batch_ndim);
1561     batch_indices.copy_(at::sparse::full_coo_indices(crow_indices.sizes().slice(0, batch_ndim), crow_indices.options())
1562                         .repeat_interleave(nnz, 1));
1563   }
1564   const input_t* crow_indices_data_in = crow_indices_->const_data_ptr<input_t>();
1565   TORCH_INTERNAL_ASSERT(indices.is_contiguous());
1566   auto row0 = indices.select(0, transpose ? batch_ndim + 1 : batch_ndim + 0);
1567   auto row1 = indices.select(0, transpose ? batch_ndim + 0 : batch_ndim + 1);
1568   output_t* data_out = row0.data_ptr<output_t>();
1569   auto col_indices_ = col_indices.expect_contiguous();
1570   row1.copy_(col_indices_->view({-1}));
1571   at::parallel_for(
1572       0, nrows * total_nnz / nnz, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
1573         for (const auto i_ : c10::irange(start, end)) {
1574           auto b = i_ / nrows;
1575           auto i = i_ % nrows;
1576           std::fill(
1577               &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i]],
1578               &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i + 1]],
1579               static_cast<output_t>(i));
1580         }
1581       });
1582 }
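
// Illustrative trace (not executed) of the kernel above for a single batch with
// crow_indices = [0, 2, 3, 3, 4] and col_indices = [1, 3, 0, 2]:
//   row1 is a copy of col_indices                     -> [1, 3, 0, 2]
//   for each row i, data_out[crow[i]:crow[i+1]] = i   -> row0 = [0, 0, 1, 3]
// With transpose == true, row0 and row1 swap roles, so the resulting COO index is
// laid out as (col, row) instead of (row, col).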
1583 } // namespace
1584 
1585 TORCH_IMPL_FUNC(_convert_indices_from_coo_to_csr_structured_cpu)
1586 (const Tensor& input,
1587  const int64_t size,
1588  const bool out_int32,
1589  const Tensor& result) {
1590   if (out_int32) {
1591     AT_DISPATCH_INTEGRAL_TYPES(
1592         input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
1593           convert_indices_from_coo_to_csr_cpu<scalar_t, int32_t>(
1594               result, input, size);
1595         });
1596   } else {
1597     AT_DISPATCH_INTEGRAL_TYPES(
1598         input.scalar_type(), "convert_indices_from_coo_to_csr_cpu", [&] {
1599           convert_indices_from_coo_to_csr_cpu<scalar_t, int64_t>(
1600               result, input, size);
1601         });
1602   }
1603 }
1604 
1605 TORCH_IMPL_FUNC(_convert_indices_from_csr_to_coo_structured_cpu)
1606 (const Tensor& crow_indices,
1607  const Tensor& col_indices,
1608  const bool out_int32,
1609  const bool transpose,
1610  const Tensor& result) {
1611   if (out_int32) {
1612     AT_DISPATCH_INTEGRAL_TYPES(
1613         crow_indices.scalar_type(), "convert_indices_from_csr_to_coo_cpu", [&] {
1614           convert_indices_from_csr_to_coo_cpu<scalar_t, int32_t>(
1615               result, crow_indices, col_indices, transpose);
1616         });
1617   } else {
1618     AT_DISPATCH_INTEGRAL_TYPES(
1619         crow_indices.scalar_type(), "convert_indices_from_csr_to_coo_cpu", [&] {
1620           convert_indices_from_csr_to_coo_cpu<scalar_t, int64_t>(
1621               result, crow_indices, col_indices, transpose);
1622         });
1623   }
1624 }
1625 
1626 /*
1627  * Based on
1628  * https://github.com/scipy/scipy/blob/8a64c938ddf1ae4c02a08d2c5e38daeb8d061d38/scipy/sparse/sparsetools/csr.h
1629  * Modified to ensure sorted BSR column indices.
1630  */
1631 template <class index_t, class scalar_t, bool compressed_rows>
1632 void _compressed_to_block_compressed_cpu_kernel(
1633     const index_t n_compressed, // Tensor size along compressed dimension
1634     const index_t n_plain, // Tensor size along plain dimension
1635     const index_t C, // Block size along compressed dimensions
1636     const index_t P, // Block size along plain dimension
1637     const index_t D, // Number of elements in dense dimensions
1638     const index_t* input_compressed_indices,
1639     const index_t* input_plain_indices,
1640     const scalar_t* input_values,
1641     index_t* result_compressed_indices,
1642     index_t* result_plain_indices,
1643     scalar_t* result_values) {
1644   // Any block may be allocated: a block gets values and indices if at
1645   // least one non-zero value lives within it; otherwise it is skipped.
1646 
1647   // Allocate pointers for all possible plain blocks plus 1
1648   std::vector<scalar_t*> blocks(n_plain / P + 1, nullptr);
1649 
1650   assert(n_compressed % C == 0);
1651   assert(n_plain % P == 0);
1652 
1653   // Number of blocks along compressed dim
1654   index_t n_bcompressed = n_compressed / C;
1655   // Number of blocks along plain_dim
1656   index_t n_bplain = n_plain / P;
1657 
1658   // Number of elements per block
1659   index_t CPD = C * P * D;
1660   // Number of blocks overall
1661   index_t n_blks = 0;
1662 
1663   result_compressed_indices[0] = 0;
1664 
1665   // Iterate over blocks along compressed dim
1666   for (index_t block_c = 0; block_c < n_bcompressed; block_c++) {
1667     // Iterate over blocks along plain dim to locate non-zero blocks,
1668     // this guarantees sorted plain dim indices
1669     for (index_t block_p = 0; block_p < n_bplain; block_p++) {
1670       for (index_t i = input_compressed_indices[C * block_c]; i < input_compressed_indices[C * (block_c + 1)]; i++) {
1671         index_t p = input_plain_indices[i]; // plain dim element index
1672         if (p / P == block_p) {
1673           blocks[block_p] = result_values + CPD * n_blks;
1674           result_plain_indices[n_blks] = block_p;
1675           n_blks++;
1676           break;
1677         }
1678       }
1679     }
1680 
1681     // Iterate over compressed dim within block
1682     for (index_t cb = 0; cb < C; cb++) {
1683       index_t c = C * block_c + cb; // compressed dim index
1684       for (index_t i = input_compressed_indices[c]; i < input_compressed_indices[c + 1]; i++) {
1685         index_t p = input_plain_indices[i]; // plain dim index
1686 
1687         // Block corresponding to plain dim index
1688         index_t block_p = p / P;
1689         // Plain dim index within block
1690         index_t pb = p % P;
1691 
1692         // Specific block entries should not be visited more than
1693         // once.  The Scipy code does an addition here. Why?
1694         // A possible answer: Scipy code supports "uncoalesced CSR"
1695         // format that allows repeated plain dim indices, and
1696         // compressed and plain indices may be unsorted.
1697         std::copy(input_values + i * D, input_values + (i + 1) * D,
1698                   blocks[block_p] + (compressed_rows ? P * cb + pb : C * pb + cb) * D);
1699       }
1700     }
1701 
1702     // Scipy code has
1703     /*
1704       for (I i = input_compressed_indices[C * block_c];
1705            i < input_compressed_indices[C * (block_c + 1)];
1706            i++) {
1707              blocks[input_plain_indices[i] / P] = 0;
1708            }
1709     */
1710     // but we don't need it because the modified code (see the block_p
1711     // loop above) does not need to evaluate `blocks[block_p] == 0`
1712     // that the original code did.
1713     result_compressed_indices[block_c + 1] = n_blks;
1714   }
1715 }
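
// Worked example (illustrative only) of the kernel above for a 4x4 CSR input with
// C = P = 2, D = 1, crow = [0, 1, 2, 2, 3], col = [0, 3, 2], values = [a, b, c],
// i.e. non-zeros at (0, 0), (1, 3) and (3, 2):
//   block row 0 touches block columns 0 and 1 -> blocks [[a, 0], [0, 0]] and [[0, 0], [0, b]]
//   block row 1 touches block column 1        -> block  [[0, 0], [c, 0]]
// so result_compressed_indices = [0, 2, 3] and result_plain_indices = [0, 1, 1].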
1716 
1717 /*
1718  * Based on
1719  * https://github.com/scipy/scipy/blob/8a64c938ddf1ae4c02a08d2c5e38daeb8d061d38/scipy/sparse/sparsetools/csr.h
1720  */
1721 template <class index_t>
1722 index_t compressed_count_blocks(
1723     const index_t n_compressed, // Tensor size along compressed dimension
1724     const index_t n_plain, // Tensor size along plain dimension
1725     const index_t C, // Block size along compressed dimensions
1726     const index_t P, // Block size along plain dimension
1727     const index_t Ac[], // Compressed indices
1728     const index_t Ap[] // Plain indices
1729   ) {
1730   std::vector<index_t> mask(n_plain / P + 1, -1);
1731   index_t n_blks = 0;
1732   for (index_t c = 0; c < n_compressed; c++) {
1733     index_t bc = c / C;
1734     for (index_t i = Ac[c]; i < Ac[c + 1]; i++) {
1735       index_t bp = Ap[i] / P;
1736       if (mask[bp] != bc) {
1737         mask[bp] = bc;
1738         n_blks++;
1739       }
1740     }
1741   }
1742   return n_blks;
1743 }
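
// Illustrative check (not executed): for the same 4x4 example as above
// (Ac = [0, 1, 2, 2, 3], Ap = [0, 3, 2], C = P = 2), block row 0 visits block
// columns 0 and 1 and block row 1 visits block column 1, so the mask trick
// counts 3 blocks, matching the three blocks allocated by the kernel above.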
1744 
1745 template<Layout target_layout>
1746 Tensor _compressed_to_block_compressed_cpu(const Tensor& self, IntArrayRef blocksize) {
1747   static_assert(target_layout == Layout::SparseBsr || target_layout == Layout::SparseBsc,
1748                 "invalid layout template parameter for _compressed_to_block_compressed_cpu");
1749 
1750   auto input_values = self.values().contiguous();
1751   Tensor input_compressed_indices;
1752   Tensor input_plain_indices;
1753   std::tie(input_compressed_indices, input_plain_indices) = sparse_csr::getCompressedPlainIndices(self);
1754   input_compressed_indices = input_compressed_indices.contiguous();
1755   input_plain_indices = input_plain_indices.contiguous();
1756 
1757   // First we determine the number of blocks needed. For each given
1758   // block, if it contains a non-zero element we will allocate values
1759   // and indices for it.
1760   int64_t num_blocks = 0;
1761   auto compressed_dim = (target_layout == Layout::SparseBsr) ? self.size(0) : self.size(1);
1762   auto plain_dim = (target_layout == Layout::SparseBsr) ? self.size(1) : self.size(0);
1763   auto compressed_blocksize = (target_layout == Layout::SparseBsr) ? blocksize[0] : blocksize[1];
1764   auto plain_blocksize = (target_layout == Layout::SparseBsr) ? blocksize[1] : blocksize[0];
1765 
1766   AT_DISPATCH_INDEX_TYPES(
1767       input_compressed_indices.scalar_type(), "_compressed_to_block_compressed_cpu", [&] {
1768         num_blocks =
1769           compressed_count_blocks<index_t>(
1770               compressed_dim,
1771               plain_dim,
1772               compressed_blocksize,
1773               plain_blocksize,
1774               input_compressed_indices.data_ptr<index_t>(),
1775               input_plain_indices.data_ptr<index_t>());
1776       });
1777   DimVector dense_shape{input_values.sizes().slice(1, input_values.dim() - 1)};
1778   DimVector values_shape{num_blocks, blocksize[0], blocksize[1]};
1779   values_shape.append(dense_shape);
1780 
1781   Tensor result_values = input_values.new_zeros(values_shape);
1782   Tensor result_compressed_indices =
1783       input_compressed_indices.new_empty({compressed_dim / compressed_blocksize + 1});
1784   Tensor result_plain_indices = input_plain_indices.new_empty({num_blocks});
1785 
1786   // Next we copy over non-zero elements into the allocated blocks.
1787   auto n_dense = std::accumulate(
1788       dense_shape.begin(), dense_shape.end(), 1, std::multiplies<int64_t>());
1789   AT_DISPATCH_INDEX_TYPES(
1790       input_compressed_indices.scalar_type(), "_compressed_to_block_compressed_cpu", [&] {
1791         AT_DISPATCH_SPARSE_VALUE_TYPES(
1792             input_values.scalar_type(), "_compressed_to_block_compressed_cpu", [&] {
1793               _compressed_to_block_compressed_cpu_kernel<index_t, scalar_t, target_layout == Layout::SparseBsr>(
1794                   compressed_dim,
1795                   plain_dim,
1796                   compressed_blocksize,
1797                   plain_blocksize,
1798                   n_dense,
1799                   input_compressed_indices.data_ptr<index_t>(),
1800                   input_plain_indices.data_ptr<index_t>(),
1801                   input_values.data_ptr<scalar_t>(),
1802                   result_compressed_indices.data_ptr<index_t>(),
1803                   result_plain_indices.data_ptr<index_t>(),
1804                   result_values.data_ptr<scalar_t>());
1805             });
1806       });
1807 
1808   return at::_sparse_compressed_tensor_unsafe(
1809       result_compressed_indices,
1810       result_plain_indices,
1811       result_values,
1812       self.sizes(),
1813       self.options().layout(target_layout));
1814 }
1815 
1816 Tensor sparse_compressed_to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1817   auto layout_to = kSparseBsr;
1818   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_bsr: unexpected same input and output layout");
1819   _to_sparse_check_arguments("sparse_compressed_to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);
1820 
1821   if (self.layout() == kSparseBsc) {
1822     return sparse_compressed_to_flipped(self, blocksize, "to_sparse_bsr");
1823   }
1824   if (self.layout() == kSparseCsr) {
1825     if (self.device() != kCPU) {
1826       TORCH_WARN("sparse_compressed_to_sparse_bsr executes on the CPU device; performance may be sub-optimal");
1827     }
1828     return _compressed_to_block_compressed_cpu<kSparseBsr>(self.cpu(), blocksize).to(self.device());
1829   }
1830   if (self.layout() == kSparseCsc) {
1831     return self.to_sparse_csr(dense_dim_opt).to_sparse_bsr(blocksize);
1832   }
1833 
1834   AT_ERROR("sparse_compressed_to_sparse_bsr: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout());
1835   return Tensor{};
1836 }
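
// Illustrative usage (not part of this file's build); as noted above, the
// CSR -> BSR step itself always runs on the CPU:
/*
  at::Tensor csr = at::rand({4, 4}).to_sparse_csr();
  at::Tensor bsr = csr.to_sparse_bsr({2, 2});   // CSR -> BSR with 2x2 blocks
*/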
1837 
1838 Tensor sparse_compressed_to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1839   auto layout_to = kSparseBsc;
1840   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse_bsc: unexpected same input and output layout");
1841   _to_sparse_check_arguments("sparse_compressed_to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);
1842 
1843   if (self.layout() == kSparseBsr) {
1844     return sparse_compressed_to_flipped(self, blocksize, "to_sparse_bsc");
1845   }
1846   if (self.layout() == kSparseCsc) {
1847     if (self.device() != kCPU) {
1848       TORCH_WARN("sparse_compressed_to_sparse_bsc executes on the CPU device; performance may be sub-optimal");
1849     }
1850     return _compressed_to_block_compressed_cpu<kSparseBsc>(self.cpu(), blocksize).to(self.device());
1851   }
1852   if (self.layout() == kSparseCsr) {
1853     return self.to_sparse_csc(dense_dim_opt).to_sparse_bsc(blocksize);
1854   }
1855 
1856   AT_ERROR("sparse_compressed_to_sparse_bsc: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout());
1857   return Tensor{};
1858 }
1859 
1860 Tensor sparse_coo_to_sparse(const Tensor& self, const int64_t sparse_dim) {
1861   _to_sparse_check_arguments("sparse_coo_to_sparse", self, sparse_dim);
1862 
1863   AT_ERROR("sparse_coo_to_sparse: ", self.layout(), " to ", kSparse, " conversion not supported");
1864   return Tensor{};
1865 }
1866 
1867 Tensor sparse_compressed_to_sparse(const Tensor& self, const int64_t sparse_dim) {
1868   _to_sparse_check_arguments("sparse_compressed_to_sparse", self, sparse_dim);
1869 
1870   Layout layout = self.layout();
1871   auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(self);
1872   Tensor values;
1873   Tensor indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices,
1874                                                         false, (layout == kSparseCsc || layout == kSparseBsc));
1875   const auto batch_ndim = compressed_indices.dim() - 1;
1876   // Only CSR is trivially coalesced
1877   bool coalesced = layout == kSparseCsr || self.numel() == 0 || self._nnz() == 1;
1878   AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "sparse_compressed_to_sparse",
1879     [&] { values = self.values().flatten(0, batch_ndim); },
1880     [&] {
1881       auto blocksize = DimVector(self.values().sizes().slice(batch_ndim + 1, 2));
1882       DimVector batch_blocksize;
1883       batch_blocksize.append(batch_ndim, 1);
1884       batch_blocksize.append(blocksize);
1885       const auto block_coo_indices = at::zeros({batch_ndim + 2, blocksize[0] * blocksize[1]}, indices.options());
1886       block_coo_indices.narrow(0, batch_ndim, 2).copy_(at::sparse::full_coo_indices(blocksize, indices.options()));
1887       indices = indices
1888         // Scale indices that identify blocks to element-wise coordinates that correspond
1889         // to the top-left corner of each block.
1890         .mul(at::tensor(batch_blocksize, indices.options()).unsqueeze_(1))
1891         // Now that we know top-left block coordinates, we offset them with element-wise
1892         // coordinates in the block to get the result.
1893         // NOTE: indices is mapped from (dim, nnz) to (dim, nnz, 1),
1894         // and block_coo_indices is mapped from (dim, block_numel) to
1895         // (dim, 1, block_numel), so the result has shape
1896         // (dim, nnz, block_numel).
1897         .unsqueeze_(-1).add(block_coo_indices.unsqueeze_(1))
1898         // Squash the nnz and the block_numel dimension
1899         // to produce valid nnz dimension of a COO tensor.
1900         .flatten(-2, -1);
1901 
1902       values = self.values().flatten(0, batch_ndim + 2);
1903 
1904       // BSR tensors whose blocks do not span across several rows produce coalesced results.
1905       coalesced |= (layout == kSparseBsr && blocksize[0] == 1 && batch_ndim == 0);
1906     });
1907   return at::native::_sparse_coo_tensor_unsafe(indices, values, self.sizes())._coalesced_(coalesced);
1908 }
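
// Illustrative expansion (not executed) of the block branch above with
// blocksize (2, 2): a block at block-row 1, block-col 0 is scaled to the
// top-left element coordinate (2, 0); adding the per-block offsets
// {(0, 0), (0, 1), (1, 0), (1, 1)} yields the COO coordinates
// (2, 0), (2, 1), (3, 0), (3, 1) for the four values stored in that block.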
1909 
1910 Tensor sparse_compressed_to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1911   auto layout_to = layout.value_or(kSparse);
1912   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_compressed_to_sparse: unexpected same input and output layout");
1913   _to_sparse_check_arguments("sparse_compressed_to_sparse", self, layout_to, blocksize, dense_dim_opt);
1914 
1915   auto blocksize_ = blocksize.value_or((self.layout() == kSparseBsr || self.layout() == kSparseBsc) ? at::sparse_csr::getBlockSize(self) : at::DimVector({1, 1}));
1916   switch (layout_to) {
1917   case kStrided:
1918     return sparse_compressed_to_dense(self, /*dtype=*/std::nullopt, /*masked_grad=*/std::nullopt);
1919   case kSparse:
1920     return sparse_compressed_to_sparse(self, 2);
1921   case kSparseCsr:
1922     return sparse_compressed_to_sparse_csr(self, dense_dim_opt);
1923   case kSparseCsc:
1924     return sparse_compressed_to_sparse_csc(self, dense_dim_opt);
1925   case kSparseBsr:
1926     return sparse_compressed_to_sparse_bsr(self, blocksize_, dense_dim_opt);
1927   case kSparseBsc:
1928     return sparse_compressed_to_sparse_bsc(self, blocksize_, dense_dim_opt);
1929   default:
1930     break;
1931   }
1932 
1933   AT_ERROR("sparse_compressed_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported");
1934   return Tensor{};
1935 }
1936 
1937 Tensor sparse_coo_to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1938   auto layout_to = layout.value_or(kSparse);
1939   TORCH_INTERNAL_ASSERT(self.layout() != layout_to, "sparse_coo_to_sparse: unexpected same input and output layout");
1940   _to_sparse_check_arguments("sparse_coo_to_sparse", self, layout_to, blocksize, dense_dim_opt);
1941 
1942   switch (layout_to) {
1943   case kStrided:
1944     return self.to_dense(std::nullopt, std::nullopt);
1945   case kSparseCsr:
1946     return self.to_sparse_csr(dense_dim_opt);
1947   case kSparseCsc:
1948     return self.to_sparse_csc(dense_dim_opt);
1949   case kSparseBsr:
1950     return self.to_sparse_bsr(*blocksize, dense_dim_opt);
1951   case kSparseBsc:
1952     return self.to_sparse_bsc(*blocksize, dense_dim_opt);
1953   default:
1954     break;
1955   }
1956 
1957   AT_ERROR("sparse_coo_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported");
1958   return Tensor{};
1959 }
1960 
1961 Tensor to_sparse(const Tensor& self, const int64_t sparse_dim) {
1962   auto layout_to = kSparse;
1963   if (self.layout() == layout_to) {
1964     _to_sparse_check_arguments("to_sparse", self, sparse_dim);
1965     return self;
1966   }
1967   return self._to_sparse(sparse_dim);
1968 }
1969 
1970 Tensor to_sparse(const Tensor& self, std::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1971   auto layout_to = layout.value_or(kSparse);
1972   if (self.layout() == layout_to) {
1973     _to_sparse_check_arguments("to_sparse", self, layout, blocksize, dense_dim_opt);
1974     return self;
1975   }
1976   return self._to_sparse(layout, blocksize, dense_dim_opt);
1977 }
1978 
1979 Tensor to_sparse_csr(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1980   auto layout_to = kSparseCsr;
1981   if (self.layout() == layout_to) {
1982     _to_sparse_check_arguments("to_sparse_csr", self, layout_to, {}, dense_dim_opt);
1983     return self;
1984   }
1985   return self._to_sparse_csr(dense_dim_opt);
1986 }
1987 
1988 Tensor to_sparse_csc(const Tensor& self, std::optional<int64_t> dense_dim_opt) {
1989   auto layout_to = kSparseCsc;
1990   if (self.layout() == layout_to) {
1991     _to_sparse_check_arguments("to_sparse_csc", self, layout_to, {}, dense_dim_opt);
1992     return self;
1993   }
1994   return self._to_sparse_csc(dense_dim_opt);
1995 }
1996 
1997 Tensor to_sparse_bsr(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
1998   auto layout_to = kSparseBsr;
1999   if (self.layout() == layout_to) {
2000     _to_sparse_check_arguments("to_sparse_bsr", self, layout_to, blocksize, dense_dim_opt);
2001     return self;
2002   }
2003   return self._to_sparse_bsr(blocksize, dense_dim_opt);
2004 }
2005 
2006 Tensor to_sparse_bsc(const Tensor& self, IntArrayRef blocksize, std::optional<int64_t> dense_dim_opt) {
2007   auto layout_to = kSparseBsc;
2008   if (self.layout() == layout_to) {
2009     _to_sparse_check_arguments("to_sparse_bsc", self, layout_to, blocksize, dense_dim_opt);
2010     return self;
2011   }
2012   return self._to_sparse_bsc(blocksize, dense_dim_opt);
2013 }
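
// Illustrative usage (not part of this file's build): calling a to_sparse* entry
// point on a tensor that is already in the requested layout returns `self`
// unchanged after argument checking.
/*
  at::Tensor csr = at::rand({4, 4}).to_sparse_csr();
  at::Tensor same = csr.to_sparse_csr();   // no conversion, returns csr itself
*/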
2014 
2015 // Sparse layout conversions End
2016 
2017 Tensor to_meta(const Tensor& tensor) {
2018   auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
2019 /*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
2020 /*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
2021   // Needs to handle wrapped numbers so that dtype promotion works properly.
2022   if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
2023     out.unsafeGetTensorImpl()->set_wrapped_number(true);
2024   }
2025   return out;
2026 }
2027 std::optional<Tensor> to_meta(const std::optional<Tensor>& tensor) {
2028   if (tensor.has_value()) {
2029     return to_meta(*tensor);
2030   }
2031   return std::nullopt;
2032 }
2033 
2034 std::vector<Tensor> to_meta(at::ITensorListRef t_list) {
2035   std::vector<Tensor> outs;
2036   outs.reserve(t_list.size());
2037   for (const auto& tensor : t_list) {
2038     outs.push_back(to_meta(tensor));
2039   }
2040   return outs;
2041 }
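
// Illustrative usage (not part of this file's build): to_meta mirrors sizes,
// strides and dtype of the input on the meta device without allocating storage.
/*
  at::Tensor x = at::rand({2, 3});
  at::Tensor m = at::native::to_meta(x);
  // m.is_meta() == true and m.sizes() == x.sizes()
*/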
2042 } // namespace at::native
2043