// aten/src/ATen/SparseCsrTensorImpl.cpp
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/native/Resize.h>

namespace at {

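// Delegating constructor: creates an empty sparse compressed tensor with
// zero-element int32 crow_indices/col_indices and an empty values tensor of
// `data_type` on the given device.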
SparseCsrTensorImpl::SparseCsrTensorImpl(
    at::DispatchKeySet key_set,
    at::Device device,
    at::Layout layout,
    const caffe2::TypeMeta data_type)
    : SparseCsrTensorImpl(
          key_set,
          data_type,
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(device)
                  .dtype(ScalarType::Int)), // crow_indices
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(device)
                  .dtype(ScalarType::Int)), // col_indices
          at::empty(
              {0},
              at::initialTensorOptions()
                  .device(device)
                  .dtype(data_type)), // values
          layout) {}

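// Main constructor: takes ownership of the index and values tensors and
// checks that the dispatch key set is consistent with the device of `values`.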
SparseCsrTensorImpl::SparseCsrTensorImpl(
    at::DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    at::Tensor crow_indices,
    at::Tensor col_indices,
    at::Tensor values,
    at::Layout layout)
    : TensorImpl(key_set, data_type, values.device()),
      crow_indices_(std::move(crow_indices)),
      col_indices_(std::move(col_indices)),
      values_(std::move(values)),
      layout_(layout) {
  // https://pytorch.org/blog/pytorch-feature-classification-changes/#beta
  TORCH_WARN_ONCE("Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensor support is in beta state. "
                  "If you miss a functionality in the sparse tensor support, please submit a feature request "
                  "to https://github.com/pytorch/pytorch/issues.");

  TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
                         || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)
                         || (key_set.has(DispatchKey::SparseCsrMeta) && device().type() == kMeta)
                         || (key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kMeta)   // fake tensor
                         || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kMeta)  // fake tensor
                         || (key_set.has(DispatchKey::SparseCsrPrivateUse1) && device().type() == kPrivateUse1)),
                        "Inconsistent key_set (=", key_set, ") and device (=", device(), ")");

  set_storage_access_should_throw();
  is_non_overlapping_and_dense_ = false;
  set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
  // comparing devices only involves comparing the type and index (two integers), we
  // can move this to a DEBUG only assert. Until then this confirms and maintains a
  // crucial invariant.
  TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
              at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
  TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
              at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
  TORCH_INTERNAL_ASSERT(values_.device() == device(),
                        "Values and compressed sparse tensor instance need to have the same device.");
}

const char* SparseCsrTensorImpl::tensorimpl_type_name() const {
  return "SparseCsrTensorImpl";
}

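// Resizes the tensor to `size` with room for up to `nnz` specified elements:
// crow_indices is extended (or truncated) to `rows + 1` entries per batch and
// col_indices/values are resized to min(nnz, rows * cols) entries per batch.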
void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "resize_ called on tensor with symbolic shape");
  auto rows = size[size.size() - 2];
  auto cols = size[size.size() - 1];
  auto old_crow_indices_size = crow_indices_.size(-1);

  auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
  new_crow_indices_size.push_back(rows + 1);
  crow_indices_.resize_(new_crow_indices_size);
  if (rows + 1 >= old_crow_indices_size) {
    crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
  } else {
    crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows * cols));
  }
  auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
  col_indices_values_size.push_back(std::min<int64_t>(nnz, rows * cols));
  col_indices_.resize_(col_indices_values_size);
  values_.resize_(col_indices_values_size);
  sizes_and_strides_.set_sizes(size);
  refresh_numel();
}

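// Resizes the tensor to `size` and drops all specified elements: the
// compressed indices are zeroed and the plain indices and values become
// empty (nse == 0). For block layouts the block size of the existing values
// tensor is preserved.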
void SparseCsrTensorImpl::resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "resize_and_clear_ called on tensor with symbolic shape");
  TORCH_CHECK(sparse_dim == 2, "resize_and_clear_ sparse dimensionality must be 2, got ", sparse_dim);
  TORCH_CHECK(static_cast<int64_t>(size.size()) >= sparse_dim + dense_dim, "resize_and_clear_ size length must be at least sparse dimensionality (=",
              sparse_dim, ") plus dense dimensionality (=", dense_dim, "), got ", size.size());
  auto batch_dim = size.size() - sparse_dim - dense_dim;
  auto batchsize = size.slice(0, batch_dim);
  auto densesize = size.slice(batch_dim + sparse_dim, dense_dim);

  auto col_indices_size = DimVector(batchsize);
  col_indices_size.push_back(0); // nse

  auto n_compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(layout_, "resize_and_clear_",
                                                                        [&] () -> int64_t { return size[batch_dim]; },
                                                                        [&] () -> int64_t { return size[batch_dim + 1]; });
  auto values_size = DimVector(batchsize);
  values_size.push_back(0); // nse
  // WARNING: in the case of block tensors, the block size is defined
  // by the existing values shape.
  int64_t block_factor = 1;
  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout_,
                                              "resize_and_clear_",
                                              [] () {},
                                              [&] () {
                                                auto blocksize = this->values_.sizes().slice(this->batch_dim() + 1, 2);
                                                values_size.append(blocksize.begin(), blocksize.end());
                                                // `the_layout` is bound to `layout_` by the dispatch macro.
                                                block_factor = blocksize[(the_layout == kSparseBsr ? 0 : 1)];
                                              });
  TORCH_CHECK(n_compressed_indices % block_factor == 0,
              "The size of the compressed dimension (=", n_compressed_indices,
              ") must be divisible by the corresponding block size (=", block_factor, ")");
  n_compressed_indices /= block_factor;
  values_size.append(densesize.begin(), densesize.end());

  auto crow_indices_size = DimVector(batchsize);
  crow_indices_size.push_back(n_compressed_indices + 1);

  crow_indices_.resize_(crow_indices_size);
  crow_indices_.zero_();
  col_indices_.resize_(col_indices_size);
  values_.resize_(values_size);
  sizes_and_strides_.set_sizes(size);
  refresh_numel();
}

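// Resizes self to match `src` (which must have the same sparse compressed
// layout), reusing the existing indices and values storage whenever their
// shapes already match.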
void SparseCsrTensorImpl::resize_as_sparse_compressed_tensor_(
    const Tensor& src) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "resize_as_sparse_compressed_tensor_ called on tensor with symbolic shape");

  // We cannot resize as another layout and still preserve the invariants of
  // the self layout.
  TORCH_CHECK(
      src.layout() == layout_,
      "resize_as_sparse_compressed_tensor_: self and src must have the same layout, but got: self (",
      layout_,
      ") and source (",
      src.layout(),
      ")");

  auto [compressed_indices, plain_indices] =
      sparse_csr::getCompressedPlainIndices(src);
  // Reuse self indices storage.
  if (crow_indices_.sizes() != compressed_indices.sizes()) {
    crow_indices_.resize_as_(compressed_indices);
  }
  if (col_indices_.sizes() != plain_indices.sizes()) {
    col_indices_.resize_as_(plain_indices);
  }
  // Update indices data to ensure the result is valid under the invariants check.
  if ((sizes() != src.sizes()) || (dense_dim() != src.dense_dim())) {
    crow_indices_.copy_(compressed_indices);
    col_indices_.copy_(plain_indices);
  }
  // Reuse values storage.
  if (values_.sizes() != src.values().sizes()) {
    values_.resize_as_(src.values());
  }
  sizes_and_strides_.set_sizes(src.sizes());
  refresh_numel();
}

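// Installs new member tensors and updates the tensor's size. The dtype of
// `values` must match the tensor's dtype, and all member tensors must be on
// the tensor's device.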
void SparseCsrTensorImpl::set_member_tensors(
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    c10::SymIntArrayRef size) {
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "set_member_tensors called on tensor with symbolic shape");

  // CSR Type Invariants
  TORCH_CHECK(
      values.scalar_type() == typeMetaToScalarType(dtype()),
      "dtype of values (",
      values.scalar_type(),
      ") must match dtype of sparse tensor (",
      typeMetaToScalarType(dtype()),
      ")");
  crow_indices_ = crow_indices;
  col_indices_ = col_indices;
  values_ = values;

  sizes_and_strides_.set_sizes(C10_AS_INTARRAYREF_SLOW(size));
  refresh_numel();
  // TODO: If this check ever shows up as a bottleneck, which is unlikely given that
  // comparing devices only involves comparing the type and index (two integers), we
  // can move this to a DEBUG only assert. Until then this confirms and maintains a
  // crucial invariant.
  TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
              at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
  TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
              at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
  TORCH_CHECK(values_.device() == device(),
              "Values and compressed tensor instance need to be on the same device.");
}

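// Convenience overload for non-symbolic sizes.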
void SparseCsrTensorImpl::set_member_tensors(
    const Tensor& crow_indices,
    const Tensor& col_indices,
    const Tensor& values,
    IntArrayRef size) {
  set_member_tensors(crow_indices, col_indices, values, c10::fromIntArrayRefSlow(size));
}

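// Sparse compressed tensors are not strided; the strided-tensor accessors and
// mutators below always throw.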
IntArrayRef SparseCsrTensorImpl::strides_custom() const {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
}
SymIntArrayRef SparseCsrTensorImpl::sym_strides_custom() const {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have strides");
}
void SparseCsrTensorImpl::set_size(int64_t dim, int64_t new_size) {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_size.");
}
void SparseCsrTensorImpl::set_stride(int64_t dim, int64_t new_stride) {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_stride.");
}
void SparseCsrTensorImpl::set_storage_offset(int64_t storage_offset) {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have set_storage_offset.");
}
bool SparseCsrTensorImpl::is_contiguous_custom(MemoryFormat) const {
  TORCH_CHECK(false, "Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensors do not have is_contiguous");
}

} // namespace at