xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/SparseTensorUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Parallel.h>
4 #include <ATen/SparseTensorImpl.h>
5 #include <ATen/core/Tensor.h>
6 
7 #ifndef AT_PER_OPERATOR_HEADERS
8 #include <ATen/Functions.h>
9 #else
10 #include <ATen/ops/empty.h>
11 #include <ATen/ops/tensor.h>
12 #endif
13 
14 namespace at::sparse {
15 
16 // Just for documentary purposes
17 using SparseTensor = Tensor;
18 using SparseType = Type;
19 
20 // This is an internal utility function for getting at the SparseTensorImpl,
21 // so that we can write sparse tensor specific accessors for special fields
22 // in SparseTensor.  You should only use this for writing low level
23 // setters/getters for SparseTensorImpl fields; otherwise, you should use
24 // the low level setters/getters that were implemented using this.
25 //
26 // This may be called repeatedly, so make sure it's pretty cheap.
get_sparse_impl(const SparseTensor & self)27 inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
28   TORCH_INTERNAL_ASSERT(
29       self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
30   return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
31 }
32 
33 // Takes indices and values and directly puts them into the sparse tensor, no
34 // copy.  This used to be called THSTensor_(_move)
alias_into_sparse(const SparseTensor & self,const Tensor & indices,const Tensor & values)35 inline void alias_into_sparse(
36     const SparseTensor& self,
37     const Tensor& indices,
38     const Tensor& values) {
39   get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
40 }
41 
42 // Take indices and values and makes a (data) copy of them to put into the
43 // sparse indices/values.  This used to be called THSTensor_(_set)
copy_into_sparse(const SparseTensor & self,const Tensor & indices,const Tensor & values,bool non_blocking)44 inline void copy_into_sparse(
45     const SparseTensor& self,
46     const Tensor& indices,
47     const Tensor& values,
48     bool non_blocking) {
49   alias_into_sparse(
50       self,
51       indices.to(self._indices().options(), non_blocking, /*copy=*/true),
52       values.to(self._values().options(), non_blocking, /*copy=*/true));
53 }
54 
55 // TODO: put this into the public API
is_same_tensor(const Tensor & lhs,const Tensor & rhs)56 inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
57   return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
58 }
59 
is_same_density(const SparseTensor & self,const SparseTensor & src)60 inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
61   return self.sparse_dim() == src.sparse_dim() &&
62       self.dense_dim() == src.dense_dim();
63 }
64 
65 // Give us a new values tensor, with the same dimensionality
66 // as 'values' but with a new number of non-zero elements.
67 // TODO: Expose this for real in ATen, some day?
68 // NB: Doesn't preserve data.
new_values_with_size_of(const Tensor & values,int64_t nnz)69 inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
70   std::vector<int64_t> size = values.sizes().vec();
71   size[0] = nnz;
72   return at::empty(size, values.options());
73 }
74 
75 // NOTE [ Flatten Sparse Indices ]
76 // This helper function flattens a sparse indices tensor (a Tensor) into a 1D
77 // indices tensor. E.g.,
78 //   input = [[2, 4, 0],
79 //            [3, 1, 10]]
80 //   full_size = [2, 12]
81 //   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
82 //
83 // In other words, assuming that each `indices[i, :]` is a valid index to a
84 // tensor `t` of shape `full_size`. This returns the corresponding indices to
85 // the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
86 // if forceClone is true, the result will forced to be a clone of self.
87 // if force_clone is true, the result will forced to be a clone of self.
88 TORCH_API Tensor flatten_indices(
89     const Tensor& indices,
90     IntArrayRef full_size,
91     bool force_clone = false);
92 
93 // Flatten sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
94 // Sparse Indices ], except this one allows partial flatten: only flatten on
95 // specified dims. Note that the flatten indices might be uncoalesced if
96 // dims_to_flatten.size() < sparse_dim. Also if input indices is already
97 // coalesced, the flattened indices will also be sorted.
98 //
99 // args:
100 //    indices: sparse tensor indices
101 //    sizes: sparse tensor sizes
102 //    dims_to_flatten: a list of dim index to flatten
103 //
104 // Ex1:
105 //   indices = [[2, 4, 0],
106 //             [3, 1, 3]]
107 //   sizes = [2, 12]
108 //   dims_to_flatten = [0, 1]
109 //   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
110 //
111 // Ex2:
112 //   dims_to_flatten = [1]
113 //   new_indices = [ 3, 1, 3 ]  # uncoalesced
114 TORCH_API Tensor flatten_indices_by_dims(
115     const Tensor& indices,
116     const IntArrayRef& sizes,
117     const IntArrayRef& dims_to_flatten);
118 
119 // Find the CSR representation for a row `indices` from the COO format
120 TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
121 
122 TORCH_API Tensor zeros_like_with_indices(const Tensor& t);
123 
124 template <size_t static_shape_max_len>
125 class TensorGeometryHolder {
126   using geometry_holder_t = std::array<int64_t, static_shape_max_len>;
127 
128  public:
129   explicit TensorGeometryHolder(
130       IntArrayRef sizes,
131       IntArrayRef strides,
132       TensorOptions options = {}) {
133     std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
134     std::copy(strides.begin(), strides.end(), t_strides.begin());
135   }
136 
TensorGeometryHolder(const Tensor & t)137   explicit TensorGeometryHolder(const Tensor& t)
138       : TensorGeometryHolder(t.sizes(), t.strides()) {}
139 
140   auto operator*() const {
141     return std::make_tuple(t_sizes, t_strides);
142   }
143 
144  private:
145   geometry_holder_t t_sizes;
146   geometry_holder_t t_strides;
147 };
148 
149 template <>
150 class TensorGeometryHolder<0> {
151   using geometry_holder_t = Tensor;
152 
153  public:
TensorGeometryHolder(IntArrayRef sizes,IntArrayRef strides,TensorOptions options)154   explicit TensorGeometryHolder(
155       IntArrayRef sizes,
156       IntArrayRef strides,
157       TensorOptions options) {
158     const int64_t t_ndims = sizes.size();
159     const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
160     Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
161     t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
162     t_sizes_and_strides_cpu.select(0, 1).copy_(
163         at::tensor(strides, cpu_options));
164     const Tensor t_sizes_and_strides =
165         t_sizes_and_strides_cpu.to(options.device());
166     t_sizes = t_sizes_and_strides.select(0, 0);
167     t_strides = t_sizes_and_strides.select(0, 1);
168   }
169 
TensorGeometryHolder(const Tensor & t)170   explicit TensorGeometryHolder(const Tensor& t)
171       : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}
172 
173   auto operator*() const {
174     return std::make_tuple(
175         t_sizes.template data_ptr<int64_t>(),
176         t_strides.template data_ptr<int64_t>());
177   }
178 
179  private:
180   geometry_holder_t t_sizes;
181   geometry_holder_t t_strides;
182 };
183 
184 // Return all indices of a tensor with the given shape.
185 //
186 // full_coo_indices(shape) is equivalent to
187 // torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
188 TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
189 
190 } // namespace at::sparse
191