#include <ATen/native/SparseTensorUtils.h>

#include <ATen/ATen.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

DEFINE_DISPATCH(flatten_indices_stub);

} // namespace at::native

namespace at::sparse {

// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
//   input = [[2, 4, 0],
//            [3, 1, 10]]
//   full_size = [2, 12]
//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each `indices[:, i]` is a valid index into a
// tensor `t` of shape `full_size`, this returns the corresponding indices into
// the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result is forced to be a clone of the input
// indices rather than a possible view.
Tensor flatten_indices(const Tensor& indices, IntArrayRef full_size, bool force_clone /*= false*/) {
  int64_t sparse_dim = indices.size(0);
  if (sparse_dim == 1) {
    if (force_clone) {
      return indices.squeeze(0).clone(at::MemoryFormat::Contiguous);
    } else {
      return indices.squeeze(0);
    }
  } else {
    if (!indices.numel()) {
      return at::zeros({indices.size(1)}, indices.options().dtype(kLong));
    }
    return at::native::flatten_indices_stub(indices.device().type(), indices, full_size.slice(0, sparse_dim));
  }
}

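// A minimal usage sketch (the tensor `idx` below is a hypothetical COO indices
// tensor of shape (2, nnz) for a tensor of size [2, 12]):
//
//   // idx = [[2, 4, 0],
//   //        [3, 1, 10]]
//   Tensor flat = at::sparse::flatten_indices(idx, /*full_size=*/{2, 12});
//   // flat == [2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10] == [27, 49, 10]
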
// Flatten a sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten Sparse Indices ],
// except this one allows a partial flatten: only flatten on the specified dims. Note that
// the flattened indices might be uncoalesced if dims_to_flatten.size() < sparse_dim.
// Also, if the input indices are already coalesced, the flattened indices will be sorted.
//
// args:
//    indices: sparse tensor indices
//    sizes: sparse tensor sizes
//    dims_to_flatten: a list of dim indices to flatten
//
// Ex1:
//   indices = [[2, 4, 0],
//              [3, 1, 3]]
//   sizes = [2, 12]
//   dims_to_flatten = [0, 1]
//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
//   dims_to_flatten = [1]
//   new_indices = [ 3, 1, 3 ]  # uncoalesced
Tensor flatten_indices_by_dims(const Tensor& indices, const IntArrayRef& sizes, const IntArrayRef& dims_to_flatten) {
  Tensor new_indices = at::zeros({indices.size(1)}, indices.options());
  for (auto d : dims_to_flatten) {
    new_indices.mul_(sizes[d]);
    new_indices.add_(indices.select(0, d));
  }
  return new_indices;
}

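// Usage sketch for the two examples above (same hypothetical `idx` with
// indices [[2, 4, 0], [3, 1, 3]] and sizes [2, 12]):
//
//   Tensor flat_all = at::sparse::flatten_indices_by_dims(idx, {2, 12}, {0, 1});
//   // flat_all == [27, 49, 3]
//   Tensor flat_d1 = at::sparse::flatten_indices_by_dims(idx, {2, 12}, {1});
//   // flat_d1 == [3, 1, 3]   (possibly uncoalesced)
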
Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz) {
  /*
    Compute the CSR row pointer array from the COO format
    Inputs:
      `indices` is the (sorted) row indices from the COO representation
      `dim` is the number of rows
      `nnz` is the number of non-zeros

    Output:
      `csr` is the compressed row pointer array of the CSR format
  */
  Tensor csr = at::zeros({dim + 1}, kLong);

  // TODO: eliminate this conditional when zero-size dims supported correctly
  if (nnz > 0) {
    auto csr_accessor = csr.accessor<int64_t, 1>();
    // Convert the sparse matrix to CSR format
    at::parallel_for(0, nnz, 10000, [&](int64_t start, int64_t end) {
      // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
      int64_t h, hp0, hp1;
      for (const auto i : c10::irange(start, end)) {
        hp0 = indices[i];
        hp1 = (i+1 == nnz) ? dim : indices[i+1];
        if (hp0 != hp1) {
          for (h = hp0; h < hp1; h++) {
            csr_accessor[h+1] = i+1;
          }
        }
      }
    });
  }
  return csr;
}

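// Worked sketch for coo_to_csr, assuming a hypothetical contiguous kLong
// tensor `rows` holding the sorted COO row indices [0, 0, 2] of a matrix
// with 3 rows:
//
//   Tensor crow = at::sparse::coo_to_csr(rows.data_ptr<int64_t>(), /*dim=*/3, /*nnz=*/3);
//   // crow == [0, 2, 2, 3]: row 0 holds entries [0, 2), row 1 is empty,
//   // row 2 holds entry [2, 3).
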
// Returns a sparse COO tensor with the same shape and indices as `t` but with
// all-zero values (a broadcast view of a single zero element).
Tensor zeros_like_with_indices(const Tensor& t) {
  TORCH_INTERNAL_ASSERT(t.is_sparse());
  return at::_sparse_coo_tensor_with_dims_and_tensors(
      t.sparse_dim(),
      t.dense_dim(),
      t.sizes(),
      t._indices().clone(),
      at::zeros({1}, t._values().options()).expand_as(t._values()),
      t.options(),
      t.is_coalesced());
}

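// Builds the full set of COO indices for a dense tensor of shape `sizes`:
// the result has shape (sizes.size(), prod(sizes)), and column j holds the
// multi-dimensional index of the j-th element in row-major order, e.g.
//   sizes = [2, 3]  ->  [[0, 0, 0, 1, 1, 1],
//                        [0, 1, 2, 0, 1, 2]]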
Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options) {
  const auto max_size = *std::max_element(sizes.begin(), sizes.end());
  const auto max_size_arange = at::arange(max_size, options);
  std::vector<Tensor> stack;
  stack.reserve(sizes.size());
  for (size_t i=0; i < sizes.size(); i++) {
    Tensor a = max_size_arange.narrow(-1, 0, sizes[i]);
    for (size_t j=0; j < sizes.size(); j++) {
      if (i != j) {
        a.unsqueeze_(j);
      }
    }
    stack.push_back(a.expand(sizes));
  }
  return at::stack(stack).flatten(1, -1);
}

} // namespace at::sparse