// Basic functions on sparse tensors
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/Layout.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/NamedTensorUtils.h>

#include <ATen/native/Copy.h>
#include <ATen/native/CPUBlas.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce.h>
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_coalesced_native.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_dimI_native.h>
#include <ATen/ops/_dimV_native.h>
#include <ATen/ops/_indices_native.h>
#include <ATen/ops/_nnz_native.h>
#include <ATen/ops/_pin_memory_native.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_native.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_native.h>
#include <ATen/ops/_validate_sparse_coo_tensor_args_native.h>
#include <ATen/ops/_values_native.h>
#include <ATen/ops/clone_native.h>
#include <ATen/ops/coalesce_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/copy_sparse_to_sparse.h>
#include <ATen/ops/copy_sparse_to_sparse_native.h>
#include <ATen/ops/dense_dim_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/index_select.h>
#include <ATen/ops/indices_native.h>
#include <ATen/ops/is_coalesced_native.h>
#include <ATen/ops/is_pinned_native.h>
#include <ATen/ops/resize_as_sparse.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/sparse_coo_tensor_native.h>
#include <ATen/ops/sparse_dim_native.h>
#include <ATen/ops/sparse_mask_native.h>
#include <ATen/ops/_sparse_mask_projection_native.h>
#include <ATen/ops/sparse_resize_and_clear_native.h>
#include <ATen/ops/sparse_resize_native.h>
#include <ATen/ops/to_dense_native.h>
#include <ATen/ops/to_sparse_native.h>
#include <ATen/ops/unique_dim.h>
#include <ATen/ops/values_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/ones.h>
#endif

namespace at::native {

using namespace at::sparse;

/******************************************************************************
 * access methods
 ******************************************************************************/

int64_t sparse_dim_sparse(const SparseTensor& self) {
  return get_sparse_impl(self)->sparse_dim();
}

int64_t dense_dim_sparse(const SparseTensor& self) {
  return get_sparse_impl(self)->dense_dim();
}

bool is_coalesced_sparse(const SparseTensor& self) {
  return get_sparse_impl(self)->coalesced();
}

bool is_coalesced_default(const Tensor& self) {
  TORCH_CHECK(false, "is_coalesced expected sparse coordinate tensor layout but got ", self.layout());
  return false;
}

int64_t _nnz_sparse(const SparseTensor& self) {
  return get_sparse_impl(self)->nnz();
}

// Why are there so many methods to get indices and values?
// See Note [ Sparse: different methods to get indices and values ] in
// native_functions.yaml
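// In short (see the definitions below): _indices()/_values() return the raw
// member tensors regardless of coalesced state, while indices()/values()
// require is_coalesced() to be true and return aliases of those members.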

Tensor _indices_sparse(const SparseTensor& self) {
  return get_sparse_impl(self)->indices();
}

Tensor _values_sparse(const SparseTensor& self) {
  return get_sparse_impl(self)->values();
}

Tensor& _coalesced_sparse_(SparseTensor& self, bool coalesced) {
  get_sparse_impl(self)->set_coalesced(coalesced);
  return self;
}

Tensor indices_sparse(const Tensor& self) {
  TORCH_CHECK(
      self.is_coalesced(),
      "Cannot get indices on an uncoalesced tensor, please call .coalesce() first");
  return get_sparse_impl(self)->indices().alias();
}

Tensor indices_default(const Tensor& self) {
  TORCH_CHECK(false, "indices expected sparse coordinate tensor layout but got ", self.layout());
}

Tensor values_sparse(const Tensor& self) {
  TORCH_CHECK(
      self.is_coalesced(),
      "Cannot get values on an uncoalesced tensor, please call .coalesce() first");
  return get_sparse_impl(self)->values().alias();
}

Tensor values_default(const Tensor& self) {
  TORCH_CHECK(false, "values expected sparse tensor layout but got ", self.layout());
}

/******************************************************************************
 * creation methods
 * See NOTE [ Sparse: autograd and API ] for details
 ******************************************************************************/

/*** Helper methods ***/

static SparseTensor new_sparse(
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  AT_ASSERT(layout.has_value() && *layout == kSparse);
  DispatchKey dispatch_key;
  switch (device_or_default(device).type()) {
#define DO_CASE(device, _) \
    case DeviceType::device: \
      dispatch_key = DispatchKey::Sparse##device; \
      break;
    C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
    default:
      TORCH_CHECK(false, "device type not supported for sparse ", device_or_default(device))
  }
  return detail::make_tensor<SparseTensorImpl>(
      DispatchKeySet(dispatch_key),
      scalarTypeToTypeMeta(dtype_or_default(dtype)));
}

/*** Actual dispatched creation methods ***/

SparseTensor new_with_dims_sparse(
    int64_t sparse_dim,
    int64_t dense_dim,
    ArrayRef<int64_t> size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
  get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size);
  return self;
}

SparseTensor new_with_dims_and_tensor_sparse_symint(
    int64_t sparse_dim,
    int64_t dense_dim,
    c10::SymIntArrayRef size,
    const Tensor& indices,
    const Tensor& values,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<bool> is_coalesced) {
  SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
  auto impl = get_sparse_impl(self);
  impl->resize_(sparse_dim, dense_dim, size);
  // NOTE: There is no guarantee that `indices` and `values` don't contain
  // AutogradMeta. However, we want to maintain the invariant that `indices_`
  // and `values_` of a sparse tensor don't contain AutogradMeta, and to achieve
  // that we shallow-copy `indices` and `values` here.
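  // (The shallow copies below share storage with `indices` and `values`; only
  // the autograd metadata is dropped, so no element data is copied here.)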
  auto indices_shallow_copy =
      Tensor(indices.unsafeGetTensorImpl()->shallow_copy_and_detach(
          /*version_counter=*/indices.unsafeGetTensorImpl()->version_counter(),
          /*allow_tensor_metadata_change=*/true));
  auto values_shallow_copy =
      Tensor(values.unsafeGetTensorImpl()->shallow_copy_and_detach(
          /*version_counter=*/values.unsafeGetTensorImpl()->version_counter(),
          /*allow_tensor_metadata_change=*/true));
  if (pin_memory.value_or(false)) {
    alias_into_sparse(self, indices_shallow_copy.pin_memory(), values_shallow_copy.pin_memory());
  } else {
    alias_into_sparse(self, indices_shallow_copy, values_shallow_copy);
  }
  // alias_into_sparse overrides the coalesced flag, so reset the flag to
  // the desired state here:
  if (is_coalesced.has_value()) {
    impl->set_coalesced(*is_coalesced);
  }
  // TODO: alias_into_sparse sets the coalesced flag to
  // `self._values().shape[0] < 2`. There exist methods (e.g. permute
  // on COO tensors when `dims[0] != 0` holds) that force the coalesced
  // flag to false even when nnz is less than 2. Here we cannot
  // determine whether this is the intention of such methods, but it is
  // likely that these methods are overly restrictive when estimating
  // the is_coalesced state. The condition `!is_coalesced && self._nnz() <
  // 2` provides a way to detect and optimize such methods with
  // respect to estimating the is_coalesced state.
  return self;
}

/** Public creation API that dispatches to the methods above **/

/** Empty init **/
Tensor empty_sparse_symint(
    SymIntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<MemoryFormat> optional_memory_format) {
  // TODO: Don't specialize
  return empty_sparse(C10_AS_INTARRAYREF_SLOW_ALLOC(size), dtype, layout, device, pin_memory, optional_memory_format);
}

Tensor empty_sparse(
    IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !pin_memory.has_value() || !*pin_memory,
      "Only dense CPU tensors can be pinned");
  return new_with_dims_sparse(
      size.size(), 0, size, dtype, layout, device, pin_memory);
}

/* Shape init */
Tensor sparse_coo_tensor(IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

  return at::_sparse_coo_tensor_with_dims(size.size(), 0, size, options.layout(at::kSparse));
}
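// Illustrative usage (sketch, not part of the original source):
//   at::sparse_coo_tensor({2, 3}, at::TensorOptions().dtype(at::kFloat))
// produces an empty 2x3 sparse tensor with sparse_dim() == 2,
// dense_dim() == 0, and _nnz() == 0.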

/* Pointer-copy init */

// helper
namespace {
static inline Tensor expand_values_if_needed(const Tensor& values) {
  // expand
  if (values.dim() == 0) {
    // Mimic Numpy behavior here and treat it as a 1D tensor
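    // (e.g. a 0-dim values tensor holding a single scalar becomes a values
    // tensor of shape [1], so callers compute dense_dim as
    // values.dim() - 1 == 0)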
    return values.expand({1});
  } else {
    return values;
  }
}
} // namespace

Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values_,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<bool> is_coalesced) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

  Tensor values = expand_values_if_needed(values_);

  // arg checking
  TORCH_CHECK(
      !options.has_layout() || options.layout() == kSparse,
      "expected sparse layout, but got layout ",
      options.layout());
  // the following checks are redundant because they are also checked in
  // SparseTensorImpl::set_indices_and_values_unsafe, but we need to ensure them
  // here in order to infer the shape.
  TORCH_CHECK(
      indices.dim() == 2,
      "indices must be sparse_dim x nnz, but got: ",
      indices.sizes())
  TORCH_CHECK(
      !indices.is_sparse(),
      "expected indices to be a dense tensor, but got indices of layout ",
      indices.layout());

  // If sizes are not given, they are inferred as the max index of each dim plus one.
  int64_t sparse_dim = indices.size(0);
  int64_t dense_dim = values.dim() - 1;

  std::vector<int64_t> computed_sizes(sparse_dim + dense_dim);
  if (indices.numel() > 0) {
    // If the indices tensor has elements in it, we infer the minimum sparse
    // dimension sizes as the max value of each dim in indices. NB: It used to
    // keepdim. I think that was wrong.
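    // Illustrative example (not part of the original source): with
    //   indices = [[0, 1, 1],
    //              [2, 0, 2]]
    // the per-dim maxima are [1, 2], so the inferred sparse sizes are [2, 3].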
    Tensor min_indices =
        std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
    Tensor computed_indices_sizes =
        std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
    computed_indices_sizes.add_(1); // len = max_index + 1
    Tensor cpu_min_indices = min_indices.to(at::DeviceType::CPU);
    Tensor cpu_computed_indices_sizes =
        computed_indices_sizes.to(at::DeviceType::CPU);
    auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
    auto cpu_computed_indices_sizes_accessor =
        cpu_computed_indices_sizes.accessor<int64_t, 1>();
    for (const auto d : c10::irange(sparse_dim)) {
      int64_t min_index_in_dim = cpu_min_indices_accessor[d];
      TORCH_CHECK(
          min_index_in_dim >= 0,
          "found negative index ",
          min_index_in_dim,
          " for dim ",
          d);
      computed_sizes[static_cast<size_t>(d)] =
          cpu_computed_indices_sizes_accessor[d];
    }
  } else {
    // If the indices tensor has no elements in it, there is not enough
    // information to know what the minimum sparse dimension sizes should be,
    // and in this case we set them to 0
    for (const auto d : c10::irange(sparse_dim)) {
      computed_sizes[static_cast<size_t>(d)] = 0;
    }
  }
  for (const auto d : c10::irange(dense_dim)) {
    computed_sizes[static_cast<size_t>(sparse_dim + d)] = values.size(d + 1);
  }

  return at::_sparse_coo_tensor_with_dims_and_tensors(
      sparse_dim,
      dense_dim,
      computed_sizes,
      indices,
      values,
      values.options().layout(kSparse),
      is_coalesced);
}

void _validate_sparse_coo_tensor_args(
    const Tensor& indices,
    const Tensor& values_,
    ArrayRef<int64_t> size,
    std::optional<bool> is_coalesced_) {
  Tensor values = expand_values_if_needed(values_);
  bool is_coalesced = is_coalesced_.value_or(false);

  // the following checks are redundant because they are also checked in
  // SparseTensorImpl::set_indices_and_values_unsafe, but we need to ensure them
  // here in order to infer the shape.
  TORCH_CHECK(
      indices.dim() == 2,
      "indices must be sparse_dim x nnz, but got: ",
      indices.sizes())
  TORCH_CHECK(
      !indices.is_sparse(),
      "expected indices to be a dense tensor, but got indices of layout ",
      indices.layout());
  int64_t sparse_dim = indices.size(0);
  int64_t dense_dim = values.dim() - 1;
  TORCH_CHECK(
      static_cast<int64_t>(size.size()) == sparse_dim + dense_dim,
      "number of dimensions must be sparse_dim (",
      sparse_dim,
      ") + dense_dim (",
      dense_dim,
      "), but got ",
      size.size());

  TORCH_CHECK(
      indices.is_pinned() == values.is_pinned(),
      "memory pinning of indices (=",
      indices.is_pinned(),
      ") must match memory pinning of values (=",
      values.is_pinned(),
      ")");

  // Check to make sure all indices are within the boundaries of `size`
  if (indices.numel() > 0) {
    Tensor min_indices =
        std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
    Tensor max_indices =
        std::get</* values */ 0>(indices.max(/* dim */ 1, /* keepdim */ false));
    Tensor cpu_min_indices, cpu_max_indices;
    if (!indices.is_cpu()) {
      cpu_min_indices = min_indices.to(at::DeviceType::CPU);
      cpu_max_indices = max_indices.to(at::DeviceType::CPU);
    } else {
      cpu_min_indices = min_indices;
      cpu_max_indices = max_indices;
    }
    auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
    auto cpu_max_indices_accessor = cpu_max_indices.accessor<int64_t, 1>();
    for (const auto d : c10::irange(sparse_dim)) {
      // NB: This used to sync ndim times to access each entry; now we copy
      // everything to CPU first and then access it.
      int64_t min_index_in_dim = cpu_min_indices_accessor[d];
      TORCH_CHECK(
          min_index_in_dim >= 0,
          "found negative index ",
          min_index_in_dim,
          " for dim ",
          d);
      int64_t max_index_in_dim = cpu_max_indices_accessor[d];
      int64_t dim_size = size[static_cast<size_t>(d)];
      TORCH_CHECK(
          max_index_in_dim < dim_size,
          "size is inconsistent with indices: for dim ",
          d,
          ", size is ",
          dim_size,
          " but found index ",
          max_index_in_dim);
    }
    if (is_coalesced && values.size(0) > 1) {
      Tensor indices_scalar = flatten_indices(indices, size);
      Tensor diff = indices_scalar.diff();
      TORCH_CHECK(diff.min().item().toLong() > 0, "cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor");
    }
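    // Illustrative note: for a coalesced tensor the flattened indices are
    // sorted and unique, i.e. strictly increasing, so every element of diff()
    // must be positive; a repeated or out-of-order index makes diff().min()
    // <= 0 and the check above fails.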
  }
}

// NB: Got rid of the sizes == NULL case
Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<bool> is_coalesced) {
  // See [Note: hacky wrapper removal for TensorOptions]
  TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
  // arg checking
  TORCH_CHECK(
      !options.has_layout() || options.layout() == kSparse,
      "expected sparse layout, but got layout ",
      options.layout());
  return at::native::_sparse_coo_tensor_unsafe(
      indices,
      values,
      size,
      optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(),
      options.device_opt(),
      options.pinned_memory_opt(),
      is_coalesced);
}

Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values_, at::IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<bool> is_coalesced) {
  if (at::globalContext().checkSparseTensorInvariants()) {
    at::native::_validate_sparse_coo_tensor_args(indices, values_, size, is_coalesced);
  }
  return at::native::_sparse_coo_tensor_unsafe_symint(indices, values_, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, is_coalesced);
}

// NOTE: _sparse_coo_tensor_unsafe() differs from sparse_coo_tensor()
// in that we don't check whether any indices are out of boundaries of `size`, thus avoiding a
// copy from CUDA to CPU. However, this function should ONLY be used where we know that the indices
// are guaranteed to be within bounds or if the caller is going to call
// _validate_sparse_coo_tensor_args before using the tensor.
// NB: Got rid of the size == NULL case
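// (A typical caller, for example, is an internal kernel that has just produced
// indices from an already-validated sparse tensor, where re-checking bounds
// would only add a device-to-host sync.)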
Tensor _sparse_coo_tensor_unsafe_symint(const Tensor& indices, const Tensor& values_, c10::SymIntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<bool> is_coalesced) {
  // See [Note: hacky wrapper removal for TensorOptions]

  Tensor values = expand_values_if_needed(values_);

  // This guard is intentional: we don't support dynamic shapes along the
  // indices dimension because that implies variable dimensionality
  auto sparse_dim = indices.sym_size(0).guard_int(__FILE__, __LINE__);
  auto dense_dim = values.dim() - 1;
  return at::_sparse_coo_tensor_with_dims_and_tensors_symint(
      sparse_dim,
      dense_dim,
      size,
      indices,
      values,
      values.options().layout(kSparse).pinned_memory(pin_memory),
      is_coalesced);
}

// NB: Deleted newWithSizeNd variants

SparseTensor clone_sparse(
    const SparseTensor& self,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !optional_memory_format.has_value(),
      "unsupported memory format option ",
      optional_memory_format.value());
  SparseTensor other = new_with_dims_sparse(
      self.sparse_dim(),
      self.dense_dim(),
      self.sizes(),
      optTypeMetaToScalarType(self.options().dtype_opt()),
      self.options().layout_opt(),
      self.options().device_opt(),
      self.options().pinned_memory_opt());
  copy_into_sparse(other, self._indices(), self._values(), true);
  return other._coalesced_(self.is_coalesced());
}

/******************************************************************************
 * reshaping methods
 ******************************************************************************/

const SparseTensor& sparse_resize_(
    const SparseTensor& self,
    ArrayRef<int64_t> size,
    int64_t sparse_dim,
    int64_t dense_dim) {
  get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
  return self;
}

const SparseTensor& sparse_resize_and_clear_(
    const SparseTensor& self,
    ArrayRef<int64_t> size,
    int64_t sparse_dim,
    int64_t dense_dim) {
  get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size);
  return self;
}

namespace {
bool _is_same_size_as_sparse(
    const SparseTensor& self,
    const SparseTensor& src) {
  return self.sparse_dim() == src.sparse_dim() &&
      self.dense_dim() == src.dense_dim() && self.sizes().equals(src.sizes());
}
} // namespace

// Invoked from native/Resize.cpp (no dynamic dispatch necessary)
const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTensor& src) {
  if (!_is_same_size_as_sparse(self, src)) {
    sparse_resize_(self, src.sizes(), src.sparse_dim(), src.dense_dim());
  }
  return self;
}

// NB: Dropped the resizeNd variants

SparseTensor& copy_sparse_wrapper_(
    Tensor& self,
    const Tensor& src,
    bool non_blocking) {
  // TODO: Once copy_ is fully migrated to use dispatcher, handle named
  // inference using dispatcher instead of doing it everywhere
  auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
  {
    NoNamesGuard guard;
    if (!self.is_sparse() || !src.is_sparse()) {
      AT_ERROR(
          "copy_() between dense and sparse Tensors is not implemented! Found self type = ",
          self.toString(),
          " and src type = ",
          src.toString());
    }
    at::copy_sparse_to_sparse_(self, src, non_blocking);
  }
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}

SparseTensor& copy_sparse_(
    SparseTensor& self,
    const SparseTensor& src,
    bool non_blocking) {
  if (is_same_tensor(self, src))
    return self;
  get_sparse_impl(self)->resize_(
      src.sparse_dim(), src.dense_dim(), src.sizes());
  copy_into_sparse(self, src._indices(), src._values(), non_blocking);
  return self._coalesced_(src.is_coalesced());
}

SparseTensor coalesce(const SparseTensor& self) {
  TORCH_CHECK(self.layout() == kSparse, "coalesce expected sparse coordinate tensor layout but got ", self.layout());
  // See NOTE: [ coalesce autograd ]
  if (self.is_coalesced()) {
    return self;
  }
  return at::_coalesce(self);
}

SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) {
  AT_ASSERT(self.defined());
  TORCH_INTERNAL_ASSERT(at::impl::variable_excluded_from_dispatch());
  AT_ASSERT(self.is_sparse());
  TORCH_INTERNAL_ASSERT(!self.is_coalesced());

  // NOTE: Since `coalesce` is not an in-place operation when `is_coalesced` is false,
  // we should keep the original tensor intact and do coalesce on a copy of the tensor
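  // Illustrative example (not part of the original source): for a 2x2 sparse
  // tensor with
  //   indices = [[0, 0, 1],
  //              [1, 1, 0]]
  //   values  = [1., 2., 3.]
  // coalescing sums the duplicate entries at (0, 1) and yields
  //   indices = [[0, 1],
  //              [1, 0]]
  //   values  = [3., 3.]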
  if (self._nnz() < 2) {
    SparseTensor dst = self.clone();
    dst._coalesced_(true);
    return dst;
  }

  Tensor indices = self._indices();
  Tensor values = self._values().contiguous();
  int64_t sparse_dim = self.sparse_dim();
  int64_t dense_dim = self.dense_dim();
  int64_t nnz = self._nnz();

  Tensor indices_scalar = flatten_indices(indices, self.sizes());

  SparseTensor dst = new_sparse(
      optTypeMetaToScalarType(self.options().dtype_opt()),
      self.options().layout_opt(),
      self.options().device_opt(),
      self.options().pinned_memory_opt());
  get_sparse_impl(dst)->resize_(sparse_dim, dense_dim, self.sizes());
  // TODO: is there a more idiomatic way to do this?
  Tensor newIndices = at::empty(indices.sizes(), indices.options());
  Tensor newValues = at::empty(values.sizes(), values.options());
  alias_into_sparse(dst, newIndices, newValues);

  auto [indicesBuffer, indicesPermutation] = indices_scalar.sort(0);
  // NB: The accessor accesses here rely on self._nnz() > 0 (tested earlier in
  // this function)
  auto newIndicesAccessor = newIndices.accessor<int64_t, 2>();
  auto indicesAccessor = indices.accessor<int64_t, 2>();
  auto indicesPermutationAccessor = indicesPermutation.accessor<int64_t, 1>();
  auto indicesBufferAccessor = indicesBuffer.accessor<int64_t, 1>();

  int64_t i = -1;
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::ComplexHalf, at::ScalarType::BFloat16, at::ScalarType::Half, at::ScalarType::Bool,
      values.scalar_type(), "coalesce", [&] {
    int64_t prev = -1;
    int64_t blockSize = values.stride(0);
    scalar_t* values_ptr = values.data_ptr<scalar_t>();
    scalar_t* newValues_ptr = newValues.data_ptr<scalar_t>();
    for (const auto j : c10::irange(nnz)) {
      int64_t pos = indicesPermutationAccessor[j];
      int64_t curr = indicesBufferAccessor[j];
      if (curr == prev) {
        // if values is an empty tensor, there are no elements to copy
        if (values.numel() > 0) {
          at::native::cpublas::axpy<scalar_t>(
              blockSize,
              static_cast<scalar_t>(1),
              values_ptr + pos * blockSize,
              1,
              newValues_ptr + i * blockSize,
              1);
        }
      } else {
        ++i;
        for (const auto d : c10::irange(sparse_dim)) {
          newIndicesAccessor[d][i] = indicesAccessor[d][pos];
        }
        // if values is an empty tensor, there are no elements to copy
        if (values.numel() > 0) {
          at::native::cpublas::copy<scalar_t>(
              blockSize,
              values_ptr + pos * blockSize,
              1,
              newValues_ptr + i * blockSize,
              1);
        }
      }
      prev = curr;
    }
  });

  dst._coalesced_(true);
  get_sparse_impl(dst)->set_nnz_and_narrow(i + 1);

  return dst;
}

DEFINE_DISPATCH(sparse_mask_intersection_out_stub);
DEFINE_DISPATCH(sparse_mask_projection_out_stub);

using OptTensor = std::optional<Tensor>;

static std::tuple<Tensor, Tensor, OptTensor> sparse_mask_like_prepare_sparse_inputs(
    const std::string& method_name,
    const Tensor& t,
    const Tensor& mask) {
  // This is a helper function for operations that implement "sparse_mask"-like
  // functionality, namely, projection of values of one tensor onto the other.
  // These operations mostly rely on COO intersection primitives that heavily
  // exploit coalesced inputs to avoid any syncs and calls to sort. The problem
  // is that these primitives might project the first argument onto the second
  // or the other way around depending on which arguments are coalesced and
  // which are larger. This function prepares inputs for `sparse_mask` such that
  // `t` is projected onto `mask` by sorting `t` if it is uncoalesced and
  // artificially marking it as coalesced, all while `mask` is flagged as
  // uncoalesced.
  // The result of this projection is going to be uncoalesced, so it is up to
  // the user to set the corresponding flag correctly with respect to the
  // operations' semantics.
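  // In short, the returned tuple is (lhs, rhs, lhs_hash_opt): `t` becomes lhs
  // (sorted by flattened indices if needed and then flagged as coalesced),
  // `mask` becomes rhs (flagged as uncoalesced), and lhs_hash_opt carries the
  // sorted flattened indices of `t` whenever such a sort was performed.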

  // We already assume that t.sizes() == mask.sizes()
  TORCH_CHECK(t.sparse_dim() == mask.sparse_dim(),
              method_name, "(): the number of sparse dimensions in `self` ",
              "should match that of the `mask`. ",
              "Got `self.sparse_dim() == ", t.sparse_dim(), "` != ",
              "`mask.sparse_dim() == ", mask.sparse_dim(), "`.");

  const auto wrapped_tensor = [](const Tensor& t,
                                 const OptTensor& indices = std::nullopt,
                                 const OptTensor& values = std::nullopt) -> Tensor {
    auto res = at::empty({0}, t.options());
    auto* res_sparse_impl = get_sparse_impl(res);
    res_sparse_impl->raw_resize_(t.sparse_dim(), t.dense_dim(), t.sizes());
    const auto res_indices = indices.has_value() ? *indices : t._indices();
    const auto res_values = values.has_value() ? *values : t._values();
    res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
    res_sparse_impl->set_nnz_and_narrow(t._nnz());
    res._coalesced_(false);
    return res;
  };

  auto [lhs, lhs_hash_opt, lhs_is_movable] = [&]() -> auto {
    if (t.is_coalesced()) {
      return std::make_tuple(t, static_cast<OptTensor>(std::nullopt), false);
    } else {
      const auto indices_hash = at::sparse::flatten_indices(t._indices(), t.sizes());
      const auto argsort_indices_hash = std::get<1>(indices_hash.sort(0));
      // Probably worth having a dedicated kernel for.
      const auto res_indices = t._indices().index_select(1, argsort_indices_hash);
      const auto res_values = t._values().index_select(0, argsort_indices_hash);
      const auto indices_hash_sorted = indices_hash.index_select(0, argsort_indices_hash);
      // NOTE: res is not necessarily coalesced, but it is sorted.
      // We mark it as "coalesced" to skip sorting in the intersection kernel.
      auto res = wrapped_tensor(t, res_indices, res_values)._coalesced_(true);
      return std::make_tuple(std::move(res), static_cast<OptTensor>(std::move(indices_hash_sorted)), true);
    }
  }();

  const auto rhs = mask.is_coalesced() ? wrapped_tensor(mask) : mask;
  const auto rhs_is_movable = mask.is_coalesced();

  return std::make_tuple(lhs_is_movable ? std::move(lhs) : lhs,
                         rhs_is_movable ? std::move(rhs) : rhs,
                         lhs_hash_opt);
}

SparseTensor sparse_mask(const Tensor& t, const SparseTensor& mask) {
  TORCH_CHECK(
      mask.sizes().equals(t.sizes()),
      "sparse_mask(): operands have incompatible sizes; self has size ",
      t.sizes(),
      " but mask has size ",
      mask.sizes());

  if (t.is_same(mask)) {
    return t;
  }

  if (!mask.numel() || !mask._nnz()) {
    return mask.clone().to(t.device(), t.scalar_type());
  }

  if (t.layout() == at::kSparse) {
    if (!t._nnz()) {
      auto res = mask.clone().to(t.device(), t.scalar_type());
      res._values().zero_();
      return res;
    }

    auto res = at::empty({0}, t.options());
    auto [lhs, rhs, lhs_hash_opt] = sparse_mask_like_prepare_sparse_inputs("sparse_mask", t, mask);
    sparse_mask_intersection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt);
    return res._coalesced_(mask.is_coalesced());
  }

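  // Dense `t`: build a sparse template of ones at the mask's indices and
  // multiply; the product keeps only the entries of `t` selected by `mask`.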
  const auto mask_values = mask._values();
  auto mask_template = at::sparse_coo_tensor(
      mask._indices(),
      at::ones({1}, mask_values.options()).expand_as(mask_values),
      mask.sizes())._coalesced_(mask.is_coalesced());
  return t.mul(mask_template).to(t.scalar_type());
}

Tensor sparse_mask_projection(const Tensor& t, const Tensor& mask, bool accumulate_matches) {
  TORCH_INTERNAL_ASSERT(t.is_sparse());
  TORCH_INTERNAL_ASSERT(mask.is_sparse());

  TORCH_CHECK(
      mask.sizes().equals(t.sizes()),
      "_sparse_mask_projection(): operands have incompatible sizes; self has size ",
      t.sizes(),
      " but mask has size ",
      mask.sizes());

  if (!t.numel() || !t._nnz() || !mask._nnz()) {
    auto res = t.clone();
    res._values().zero_();
    return res;
  }

  auto res = at::empty({0}, t.options());
  auto [lhs, rhs, lhs_hash_opt] = sparse_mask_like_prepare_sparse_inputs("_sparse_mask_projection", mask, t);
  sparse_mask_projection_out_stub(res.device().type(), res, lhs, rhs, lhs_hash_opt, accumulate_matches);
  return res._coalesced_(t.is_coalesced());
}

Tensor empty_like_sparse_coo(
    const Tensor& self,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

  TORCH_CHECK(
    !(options_.has_memory_format() && optional_memory_format.has_value()),
    "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
    "the redundant setter.");

  TensorOptions options =
      self.options()
          .merge_in(options_)
          .merge_memory_format(optional_memory_format);

  TORCH_CHECK(
      !(options.layout() != kStrided &&
          optional_memory_format.has_value()),
      "memory format option is only supported by strided tensors");

  if (options.layout() == kSparse) {
    auto result = at::empty({0}, options);
    result.sparse_resize_and_clear_(
        self.sizes(), self.sparse_dim(), self.dense_dim());
    return result;
  } else {
    return at::native::empty_like(self, dtype, layout, device, pin_memory, optional_memory_format);
  }
}

bool is_pinned_sparse_coo(const Tensor& self, std::optional<Device> device) {
  // Assuming that _indices has the same pin memory state as _values
  return self._values().is_pinned(device);
}

Tensor _pin_memory_sparse_coo(const Tensor& self, std::optional<Device> device) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!device.has_value() || device->is_cuda());
  // Pinning a sparse tensor is equivalent to cloning its indices and values,
  // which does not change the sparse tensor invariants. Hence, we can skip
  // checking the sparse tensor invariants for efficiency.
  at::sparse_csr::CheckSparseTensorInvariants _(false);
  TensorOptions options = self.options().pinned_memory(true);
  return at::_sparse_coo_tensor_with_dims_and_tensors(
      self.sparse_dim(),
      self.dense_dim(),
      self.sizes(),
      self._indices().pin_memory(device),
      self._values().pin_memory(device),
      options,
      self.is_coalesced());
}

} // namespace at::native