#include <ATen/native/SparseTensorUtils.h>

#include <ATen/ATen.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

DEFINE_DISPATCH(flatten_indices_stub);

} // namespace at::native

namespace at::sparse {

// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
//   input = [[2, 4, 0],
//            [3, 1, 10]]
//   full_size = [5, 12]
//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each column `indices[:, i]` is a valid index
// into a tensor `t` of shape `full_size`, this returns the corresponding
// indices into the flattened tensor
// `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result is forced to be a clone of the input
// indices.
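//
// Illustrative usage sketch (not compiled as part of this file; the literal
// values are only an assumption mirroring the example above, with the indices
// assumed to be a CPU int64 tensor of shape (sparse_dim, nnz)):
//
//   Tensor indices =
//       at::tensor(c10::ArrayRef<int64_t>{2, 4, 0, 3, 1, 10}).view({2, 3});
//   Tensor flat = at::sparse::flatten_indices(indices, {5, 12});
//   // flat == [27, 49, 10]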
Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone /*= false*/) {
  int64_t sparse_dim = indices.size(0);
  if (sparse_dim == 1) {
    if (force_clone) {
      return indices.squeeze(0).clone(at::MemoryFormat::Contiguous);
    } else {
      return indices.squeeze(0);
    }
  } else {
    if (!indices.numel()) {
      return at::zeros({indices.size(1)}, indices.options().dtype(kLong));
    }
    return at::native::flatten_indices_stub(indices.device().type(), indices, full_size.slice(0, sparse_dim));
  }
}

// Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
// except this one allows a partial flatten: only the specified dims are flattened. Note that
// the flattened indices might be uncoalesced if dims_to_flatten.size() < sparse_dim.
// Also, if the input indices are already coalesced, the flattened indices will be sorted.
//
// args:
//   indices: sparse tensor indices
//   sizes: sparse tensor sizes
//   dims_to_flatten: a list of dim indices to flatten
//
// Ex1:
//   indices = [[2, 4, 0],
//              [3, 1, 3]]
//   sizes = [5, 12]
//   dims_to_flatten = [0, 1]
//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
//   dims_to_flatten = [1]
//   new_indices = [ 3, 1, 3 ]  # uncoalesced
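//
// Illustrative usage sketch (not compiled as part of this file; the values
// mirror Ex1/Ex2 above and are only an assumption):
//
//   Tensor indices =
//       at::tensor(c10::ArrayRef<int64_t>{2, 4, 0, 3, 1, 3}).view({2, 3});
//   // Flatten both dims (same result as flatten_indices): [27, 49, 3]
//   Tensor full = at::sparse::flatten_indices_by_dims(indices, {5, 12}, {0, 1});
//   // Flatten only dim 1: [3, 1, 3], possibly uncoalesced
//   Tensor part = at::sparse::flatten_indices_by_dims(indices, {5, 12}, {1});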
Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten) {
  Tensor new_indices = at::zeros({indices.size(1)}, indices.options());
  for (auto d : dims_to_flatten) {
    new_indices.mul_(sizes[d]);
    new_indices.add_(indices.select(0, d));
  }
  return new_indices;
}

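// Illustrative sketch of the mapping coo_to_csr computes (not compiled as part
// of this file; the row indices below are only an assumption). For a matrix
// with 4 rows whose COO row indices are sorted ascending:
//
//   std::vector<int64_t> rows = {0, 0, 1, 3};
//   Tensor crow = at::sparse::coo_to_csr(rows.data(), /*dim=*/4, /*nnz=*/4);
//   // crow == [0, 2, 3, 3, 4]: row 0 owns entries [0, 2), row 1 owns [2, 3),
//   // row 2 is empty, and row 3 owns [3, 4).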
Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
  /*
    Build the CSR row pointer array from the row indices of a sparse matrix in
    COO format.
    Inputs:
      `indices` is the array of row indices from the COO indices
      `dim` is the number of rows
      `nnz` is the number of non-zeros

    Output:
      `csr` is the compressed row pointer array of the CSR format, with
      `dim + 1` entries
  */
  Tensor csr = at::zeros({dim + 1}, kLong);

  // TODO: eliminate this conditional when zero-size dims supported correctly
  if (nnz > 0) {
    auto csr_accessor = csr.accessor<int64_t, 1>();
    // Convert the sparse matrix to CSR format
    at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t h, hp0, hp1;
      for (const auto i : c10::irange(start, end)) {
        hp0 = indices[i];
        hp1 = (i+1 == nnz) ? dim : indices[i+1];
        if (hp0 != hp1) {
          for (h = hp0; h < hp1; h++) {
            csr_accessor[h+1] = i+1;
          }
        }
      }
    });
  }
  return csr;
}

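// Returns a sparse COO tensor with the same shape and sparsity pattern as `t`
// (a clone of t's indices) but with every stored value equal to zero; the
// values tensor is an expanded view of a single zero element. Sketch (not
// compiled as part of this file; `t` is assumed to be any sparse COO tensor):
//
//   Tensor z = at::sparse::zeros_like_with_indices(t);
//   // z._indices() equals t._indices(); z._values() is all zeros.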
Tensor zeros_like_with_indices(const Tensor& t) {
  TORCH_INTERNAL_ASSERT(t.is_sparse());
  return at::_sparse_coo_tensor_with_dims_and_tensors(
      t.sparse_dim(),
      t.dense_dim(),
      t.sizes(),
      t._indices().clone(),
      at::zeros({1}, t._values().options()).expand_as(t._values()),
      t.options(),
      t.is_coalesced());
}

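// Returns the indices of all elements of a dense tensor with shape `sizes`,
// as a (sizes.size(), prod(sizes)) tensor of coordinates enumerated in
// row-major order. Illustrative sketch (not compiled as part of this file;
// the kLong options are only an assumption):
//
//   Tensor idx = at::sparse::full_coo_indices({2, 3}, at::kLong);
//   // idx == [[0, 0, 0, 1, 1, 1],
//   //         [0, 1, 2, 0, 1, 2]]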
Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options) {
  const auto max_size = *std::max_element(sizes.begin(), sizes.end());
  const auto max_size_arange = at::arange(max_size, options);
  std::vector<Tensor> stack;
  stack.reserve(sizes.size());
  for (size_t i = 0; i < sizes.size(); i++) {
    Tensor a = max_size_arange.narrow(-1, 0, sizes[i]);
    for (size_t j = 0; j < sizes.size(); j++) {
      if (i != j) {
        a.unsqueeze_(j);
      }
    }
    stack.push_back(a.expand(sizes));
  }
  return at::stack(stack).flatten(1, -1);
}

} // namespace at::sparse
