1 #include <ATen/ATen.h>
2 #include <ATen/InitialTensorOptions.h>
3 #include <ATen/SparseCsrTensorImpl.h>
4 #include <ATen/SparseCsrTensorUtils.h>
5 #include <ATen/SparseTensorImpl.h>
6 #include <ATen/core/LegacyTypeDispatch.h>
7 #include <ATen/native/Resize.h>
8
9 namespace at {
10
SparseCsrTensorImpl(at::DispatchKeySet key_set,at::Device device,at::Layout layout,const caffe2::TypeMeta data_type)11 SparseCsrTensorImpl::SparseCsrTensorImpl(
12 at::DispatchKeySet key_set,
13 at::Device device,
14 at::Layout layout,
15 const caffe2::TypeMeta data_type)
16 : SparseCsrTensorImpl(
17 key_set,
18 data_type,
19 at::empty(
20 {0},
21 at::initialTensorOptions()
22 .device(device)
23 .dtype(ScalarType::Int)) // crow_indices
24 ,
25 at::empty(
26 {0},
27 at::initialTensorOptions()
28 .device(device)
29 .dtype(ScalarType::Int)) // col_indices
30 ,
31 at::empty(
32 {0},
33 at::initialTensorOptions()
34 .device(device)
35 .dtype(data_type)) // values
36 ,
37 layout
38 ) {}
39
SparseCsrTensorImpl(at::DispatchKeySet key_set,const caffe2::TypeMeta data_type,at::Tensor crow_indices,at::Tensor col_indices,at::Tensor values,at::Layout layout)40 SparseCsrTensorImpl::SparseCsrTensorImpl(
41 at::DispatchKeySet key_set,
42 const caffe2::TypeMeta data_type,
43 at::Tensor crow_indices,
44 at::Tensor col_indices,
45 at::Tensor values,
46 at::Layout layout)
47 : TensorImpl(key_set, data_type, values.device()),
48 crow_indices_(std::move(crow_indices)),
49 col_indices_(std::move(col_indices)),
50 values_(std::move(values)),
51 layout_(layout) {
52 // https://pytorch.org/blog/pytorch-feature-classification-changes/#beta
53 TORCH_WARN_ONCE("Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensor support is in beta state. "
54 "If you miss a functionality in the sparse tensor support, please submit a feature request "
55 "to https://github.com/pytorch/pytorch/issues.");
56
57 TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
58 || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)
59 || (key_set.has(DispatchKey::SparseCsrMeta) && device().type() == kMeta)
60 || (key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kMeta) // fake tensor
61 || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kMeta) // fake tensor
62 || (key_set.has(DispatchKey::SparseCsrPrivateUse1) && device().type() == kPrivateUse1)),
63 "Inconsistent key_set (=", key_set, ") and device (=", device(), ")");
64
65 set_storage_access_should_throw();
66 is_non_overlapping_and_dense_ = false;
67 set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
68 // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
69 // comparing devices only involves comparing the type and index (two integers), we
70 // can move this to a DEBUG only assert. Until then this confirms and maintains a
71 // crucial invariance.
72 TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
73 at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
74 TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
75 at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
76 TORCH_INTERNAL_ASSERT(values_.device() == device(),
77 "Values and compressed sparse tensor instance need to have the same device.");
78 }
79
tensorimpl_type_name() const80 const char* SparseCsrTensorImpl::tensorimpl_type_name() const {
81 return "SparseCsrTensorImpl";
82 }
83
resize_(int64_t nnz,IntArrayRef size)84 void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
85 TORCH_CHECK(
86 !has_symbolic_sizes_strides_,
87 "resize_ called on tensor with symbolic shape")
88 auto rows = size[size.size() - 2];
89 auto cols = size[size.size() - 1];
90 auto old_crow_indices_size = crow_indices_.size(-1);
91
92 auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
93 new_crow_indices_size.push_back(rows + 1);
94 crow_indices_.resize_(new_crow_indices_size);
95 if (rows + 1 >= old_crow_indices_size) {
96 crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
97 } else {
98 crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
99 }
100 auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
101 col_indices_values_size.push_back(std::min<int64_t>(nnz, rows*cols));
102 col_indices_.resize_(col_indices_values_size);
103 values_.resize_(col_indices_values_size);
104 sizes_and_strides_.set_sizes(size);
105 refresh_numel();
106 }
107
resize_and_clear_(int64_t sparse_dim,int64_t dense_dim,IntArrayRef size)108 void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
109 TORCH_CHECK(
110 !has_symbolic_sizes_strides_,
111 "resize_and_clear_ called on tensor with symbolic shape");
112 TORCH_CHECK(sparse_dim == 2, "resize_and_clear_ sparse dimensionality must be 2, got ", sparse_dim);
113 TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim + dense_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
114 sparse_dim, ") plus dense dimensionality (=", dense_dim, "), got ", size.size());
115 auto batch_dim = size.size() - sparse_dim - dense_dim;
116 auto batchsize = size.slice(0, batch_dim);
117 auto densesize = size.slice(batch_dim + sparse_dim, dense_dim);
118
119 auto col_indices_size = DimVector(batchsize);
120 col_indices_size.push_back(0); // nse
121
122 auto n_compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout_, "resize_and_clear_",
123 [&] () -> int64_t { return size[batch_dim]; },
124 [&] () -> int64_t { return size[batch_dim + 1]; }
125 );
126 auto values_size = DimVector(batchsize);
127 values_size.push_back(0); // nse
128 // WARNING: in the case of block tensors, the block size is defined
129 // by the existing values shape.
130 int64_t block_factor = 1;
131 AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
132 "resize_and_clear_",
133 [] () {},
134 [&] () {
135 auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
136 values_size.append(blocksize.begin(), blocksize.end());
137 block_factor = blocksize[(the_layout == kSparseBsr ? 0 : 1)];
138
139 });
140 TORCH_CHECK(n_compressed_indices % block_factor == 0,
141 "The size of the compressed dimension (=", n_compressed_indices,
142 ") must be divisible with the corresponding block size (=", block_factor,")");
143 n_compressed_indices /= block_factor;
144 values_size.append(densesize.begin(), densesize.end());
145
146 auto crow_indices_size = DimVector(batchsize);
147 crow_indices_size.push_back(n_compressed_indices + 1);
148
149 crow_indices_.resize_(crow_indices_size);
150 crow_indices_.zero_();
151 col_indices_.resize_(col_indices_size);
152 values_.resize_(values_size);
153 sizes_and_strides_.set_sizes(size);
154 refresh_numel();
155 }
156
resize_as_sparse_compressed_tensor_(const Tensor & src)157 void SparseCsrTensorImpl::resize_as_sparse_compressed_tensor_(
158 const Tensor& src) {
159 TORCH_CHECK(
160 !has_symbolic_sizes_strides_,
161 "resize_as_sparse_compressed_tensor_ called on tensor with symbolic shape");
162
163 // We cannot resize as other layout and preserve the invariants for self
164 // layout
165 TORCH_CHECK(
166 src.layout() == layout_,
167 "resize_as_sparse_compressed_tensor_: self and src must have the same layout, but got: self (",
168 layout_,
169 ") and source (",
170 src.layout(),
171 ")");
172
173 auto [compressed_indices, plain_indices] =
174 sparse_csr::getCompressedPlainIndices(src);
175 // reuse self indices storage
176 if (crow_indices_.sizes() != compressed_indices.sizes()) {
177 crow_indices_.resize_as_(compressed_indices);
178 }
179 if (col_indices_.sizes() != plain_indices.sizes()) {
180 col_indices_.resize_as_(plain_indices);
181 }
182 // Update indices data to ensure result is valid under invariants check
183 if ((sizes() != src.sizes()) || (dense_dim() != src.dense_dim())) {
184 crow_indices_.copy_(compressed_indices);
185 col_indices_.copy_(plain_indices);
186 }
187 // Reuse values storage
188 if (values_.sizes() != src.values().sizes()) {
189 values_.resize_as_(src.values());
190 }
191 sizes_and_strides_.set_sizes(src.sizes());
192 refresh_numel();
193 }
194
set_member_tensors(const Tensor & crow_indices,const Tensor & col_indices,const Tensor & values,c10::SymIntArrayRef size)195 void SparseCsrTensorImpl::set_member_tensors(
196 const Tensor& crow_indices,
197 const Tensor& col_indices,
198 const Tensor& values,
199 c10::SymIntArrayRef size) {
200 TORCH_CHECK(
201 !has_symbolic_sizes_strides_,
202 "set_member_tensors called on tensor with symbolic shape");
203
204 // CSR Type Invariants
205 TORCH_CHECK(
206 values.scalar_type() == typeMetaToScalarType(dtype()),
207 "dtype of values (",
208 values.scalar_type(),
209 ") must match dtype of sparse tensor (",
210 typeMetaToScalarType(dtype()),
211 ")");
212 crow_indices_ = crow_indices;
213 col_indices_ = col_indices;
214 values_ = values;
215
216 sizes_and_strides_.set_sizes(C10_AS_INTARRAYREF_SLOW(size));
217 refresh_numel();
218 // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
219 // comparing devices only involves comparing the type and index (two integers), we
220 // can move this to a DEBUG only assert. Until then this confirms and maintains a
221 // crucial invariance.
222 TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
223 at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
224 TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
225 at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
226 TORCH_CHECK(values_.device() == device(),
227 "Values and compressed tensor instance need to be on the same device.");
228 }
229
set_member_tensors(const Tensor & crow_indices,const Tensor & col_indices,const Tensor & values,IntArrayRef size)230 void SparseCsrTensorImpl::set_member_tensors(
231 const Tensor& crow_indices,
232 const Tensor& col_indices,
233 const Tensor& values,
234 IntArrayRef size) {
235 set_member_tensors(crow_indices, col_indices, values, c10::fromIntArrayRefSlow(size));
236 }
237
strides_custom() const238 IntArrayRef SparseCsrTensorImpl::strides_custom() const {
239 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
240 }
sym_strides_custom() const241 SymIntArrayRef SparseCsrTensorImpl::sym_strides_custom() const {
242 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
243 }
set_size(int64_t dim,int64_t new_size)244 void SparseCsrTensorImpl::set_size(int64_t dim, int64_t new_size) {
245 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_size.");
246 }
set_stride(int64_t dim,int64_t new_stride)247 void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
248 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_stride.");
249 }
set_storage_offset(int64_t storage_offset)250 void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
251 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
252 }
is_contiguous_custom(MemoryFormat) const253 bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const {
254 TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
255 }
256
257 } // namespace at
258