#pragma once

#include <ATen/Parallel.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#endif

namespace at::sparse {

// Just for documentary purposes
using SparseTensor = Tensor;
using SparseType = Type;

// This is an internal utility function for getting at the SparseTensorImpl,
// so that we can write sparse tensor specific accessors for special fields
// in SparseTensor. You should only use this for writing low level
// setters/getters for SparseTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
//
// This may be called repeatedly, so make sure it's pretty cheap.
inline SparseTensorImpl* get_sparse_impl(const SparseTensor& self) {
  TORCH_INTERNAL_ASSERT(
      self.is_sparse(), "_internal_get_SparseTensorImpl: not a sparse tensor");
  return static_cast<SparseTensorImpl*>(self.unsafeGetTensorImpl());
}
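
// For example (an illustrative sketch only; the tensor construction below is
// an assumption made for the example, not part of this header), reading
// low-level fields through the impl:
//
//   Tensor idx = at::tensor(std::vector<int64_t>{0, 1, 1, 0, 2, 2}).view({2, 3});
//   Tensor vals = at::tensor(std::vector<double>{1.0, 2.0, 3.0});
//   SparseTensor t = at::sparse_coo_tensor(idx, vals, {2, 3});
//   SparseTensorImpl* impl = get_sparse_impl(t);
//   int64_t nnz = impl->nnz();         // 3
//   int64_t sd = impl->sparse_dim();   // 2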

// Takes indices and values and directly puts them into the sparse tensor, no
// copy. This used to be called THSTensor_(_move)
inline void alias_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values) {
  get_sparse_impl(self)->set_indices_and_values_unsafe(indices, values);
}

// Takes indices and values and makes a (data) copy of them to put into the
// sparse indices/values. This used to be called THSTensor_(_set)
inline void copy_into_sparse(
    const SparseTensor& self,
    const Tensor& indices,
    const Tensor& values,
    bool non_blocking) {
  alias_into_sparse(
      self,
      indices.to(self._indices().options(), non_blocking, /*copy=*/true),
      values.to(self._values().options(), non_blocking, /*copy=*/true));
}
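
// To illustrate the difference between the two helpers above (a sketch; the
// sparse tensor `t` is assumed to already exist): `alias_into_sparse` makes
// `t` share the given tensors directly, so later in-place changes to `vals`
// are observable through `t._values()`, while `copy_into_sparse` first
// converts/copies them to the options of `t`'s existing indices and values:
//
//   Tensor idx = at::zeros({1, 3}, at::kLong);
//   Tensor vals = at::ones({3});
//   alias_into_sparse(t, idx, vals);   // t._values() now aliases vals
//   vals.mul_(2);                      // ...so this change is visible in t
//   copy_into_sparse(t, idx, vals, /*non_blocking=*/false);  // t owns copies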

// TODO: put this into the public API
inline bool is_same_tensor(const Tensor& lhs, const Tensor& rhs) {
  return lhs.unsafeGetTensorImpl() == rhs.unsafeGetTensorImpl();
}

inline bool is_same_density(const SparseTensor& self, const SparseTensor& src) {
  return self.sparse_dim() == src.sparse_dim() &&
      self.dense_dim() == src.dense_dim();
}

// Give us a new values tensor, with the same dimensionality
// as 'values' but with a new number of non-zero elements.
// TODO: Expose this for real in ATen, some day?
// NB: Doesn't preserve data.
inline Tensor new_values_with_size_of(const Tensor& values, int64_t nnz) {
  std::vector<int64_t> size = values.sizes().vec();
  size[0] = nnz;
  return at::empty(size, values.options());
}
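
// Illustrative sketch: given values with 5 non-zeros and a dense part of
// shape [3], grow the non-zero dimension to 8 (the contents are left
// uninitialized):
//
//   Tensor values = at::zeros({5, 3});
//   Tensor grown = new_values_with_size_of(values, 8);  // shape [8, 3]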

// NOTE [ Flatten Sparse Indices ]
// This helper function flattens a sparse indices tensor (a Tensor) into a 1D
// indices tensor. E.g.,
//   input = [[2, 4, 0],
//            [3, 1, 10]]
//   full_size = [5, 12]
//   output = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 10 ] = [27, 49, 10]
//
// In other words, assuming that each `indices[i, :]` is a valid index into a
// tensor `t` of shape `full_size`, this returns the corresponding indices
// into the flattened tensor `t.reshape( prod(full_size[:indices.size(0)]), -1 )`.
// If force_clone is true, the result is forced to be a clone of the input
// indices.
TORCH_API Tensor flatten_indices(
    const Tensor& indices,
    IntArrayRef full_size,
    bool force_clone = false);
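
// A minimal sketch of the math for the two-sparse-dim case above (the
// in-tree implementation generalizes to arbitrary sparse_dim and devices,
// so treat this only as a description of the computation, not as the
// actual kernel):
//
//   Tensor flatten_2d_indices(const Tensor& indices, IntArrayRef full_size) {
//     // new_index = index[0] * full_size[1] + index[1]
//     return indices.select(0, 0) * full_size[1] + indices.select(0, 1);
//   }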

// Flatten a sparse tensor's indices from nD to 1D, similar to NOTE [ Flatten
// Sparse Indices ], except this one allows a partial flatten: only the
// specified dims are flattened. Note that the flattened indices might be
// uncoalesced if dims_to_flatten.size() < sparse_dim. Also, if the input
// indices are already coalesced, the flattened indices will also be sorted.
//
// args:
//    indices: sparse tensor indices
//    sizes: sparse tensor sizes
//    dims_to_flatten: a list of dims to flatten
//
// Ex1:
//   indices = [[2, 4, 0],
//              [3, 1, 3]]
//   sizes = [5, 12]
//   dims_to_flatten = [0, 1]
//   new_indices = [ 2 * 12 + 3, 4 * 12 + 1, 0 * 12 + 3 ] = [27, 49, 3]
//
// Ex2:
//   dims_to_flatten = [1]
//   new_indices = [ 3, 1, 3 ] # uncoalesced
TORCH_API Tensor flatten_indices_by_dims(
    const Tensor& indices,
    const IntArrayRef& sizes,
    const IntArrayRef& dims_to_flatten);
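
// A sketch of the underlying idea (not necessarily the exact in-tree
// kernel): accumulate a row-major linear index over just the flattened
// dims, ignoring all other dims:
//
//   Tensor flatten_by_dims_sketch(
//       const Tensor& indices,
//       IntArrayRef sizes,
//       IntArrayRef dims_to_flatten) {
//     Tensor new_indices = at::zeros({indices.size(1)}, indices.options());
//     for (auto d : dims_to_flatten) {
//       new_indices.mul_(sizes[d]);             // shift earlier dims over
//       new_indices.add_(indices.select(0, d)); // add this dim's coordinates
//     }
//     return new_indices;
//   }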

// Compute the CSR row pointer array from the COO row `indices` array
// (length `nnz`) of a tensor with `dim` rows.
TORCH_API Tensor coo_to_csr(const int64_t* indices, int64_t dim, int64_t nnz);
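
// A minimal sequential sketch of the standard counting + prefix-sum
// construction (an assumption about the approach; the in-tree kernel may
// differ, e.g. by parallelizing over nnz). The resulting row pointers are
// only meaningful when the COO row indices are sorted, as they are for a
// coalesced tensor:
//
//   Tensor coo_to_csr_sketch(const int64_t* indices, int64_t dim, int64_t nnz) {
//     Tensor csr = at::zeros({dim + 1}, at::kLong);
//     int64_t* csr_data = csr.data_ptr<int64_t>();
//     for (int64_t i = 0; i < nnz; i++) {
//       csr_data[indices[i] + 1] += 1;    // count entries in each row
//     }
//     for (int64_t h = 0; h < dim; h++) {
//       csr_data[h + 1] += csr_data[h];   // prefix sum -> row pointers
//     }
//     return csr;
//   }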

TORCH_API Tensor zeros_like_with_indices(const Tensor& t);

// Helper that captures a tensor's sizes and strides so they can be handed to
// a kernel uniformly: this generic version copies them into fixed-size
// std::array members (cheap to copy by value), while the <0> specialization
// below materializes them as a tensor on the target device and exposes raw
// int64_t pointers.
template <size_t static_shape_max_len>
class TensorGeometryHolder {
  using geometry_holder_t = std::array<int64_t, static_shape_max_len>;

 public:
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options = {}) {
    // `options` is unused here; it exists only to mirror the interface of
    // the <0> specialization below.
    std::copy(sizes.begin(), sizes.end(), t_sizes.begin());
    std::copy(strides.begin(), strides.end(), t_strides.begin());
  }

  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides()) {}

  auto operator*() const {
    return std::make_tuple(t_sizes, t_strides);
  }

 private:
  geometry_holder_t t_sizes;
  geometry_holder_t t_strides;
};

template <>
class TensorGeometryHolder<0> {
  using geometry_holder_t = Tensor;

 public:
  explicit TensorGeometryHolder(
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options) {
    const int64_t t_ndims = sizes.size();
    const auto cpu_options = TensorOptions(options).dtype(kLong).device(kCPU);
    Tensor t_sizes_and_strides_cpu = at::empty({2, t_ndims}, cpu_options);
    t_sizes_and_strides_cpu.select(0, 0).copy_(at::tensor(sizes, cpu_options));
    t_sizes_and_strides_cpu.select(0, 1).copy_(
        at::tensor(strides, cpu_options));
    const Tensor t_sizes_and_strides =
        t_sizes_and_strides_cpu.to(options.device());
    t_sizes = t_sizes_and_strides.select(0, 0);
    t_strides = t_sizes_and_strides.select(0, 1);
  }

  explicit TensorGeometryHolder(const Tensor& t)
      : TensorGeometryHolder(t.sizes(), t.strides(), t.options()) {}

  auto operator*() const {
    return std::make_tuple(
        t_sizes.template data_ptr<int64_t>(),
        t_strides.template data_ptr<int64_t>());
  }

 private:
  geometry_holder_t t_sizes;
  geometry_holder_t t_strides;
};
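
// Illustrative usage sketch (the tensor and the static length 8 are
// assumptions for the example): both specializations are dereferenced the
// same way, but one yields host std::arrays and the other yields raw
// int64_t pointers into a tensor on the target device.
//
//   Tensor t = at::empty({4, 5});
//   auto host_geometry = TensorGeometryHolder<8>(t);
//   auto [sizes_arr, strides_arr] = *host_geometry;  // std::array<int64_t, 8>
//
//   auto dev_geometry = TensorGeometryHolder<0>(t);
//   auto [sizes_ptr, strides_ptr] = *dev_geometry;   // int64_t* pointers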

// Return all indices of a tensor with the given shape.
//
// full_coo_indices(shape) is equivalent to
// torch.ones(shape).nonzero().transpose(-2, -1) but much faster.
TORCH_API Tensor full_coo_indices(IntArrayRef sizes, TensorOptions options);
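
// Illustrative sketch: for a 2 x 3 shape the result has shape [2, 6] and
// lists every coordinate in row-major order:
//
//   Tensor idx = full_coo_indices({2, 3}, at::kLong);
//   // idx == [[0, 0, 0, 1, 1, 1],
//   //         [0, 1, 2, 0, 1, 2]]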

} // namespace at::sparse