1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/TensorIndexing.h>
3 #include <ATen/native/sparse/SparseTensorMath.h>
4
5 #include <c10/util/irange.h>
6 #include <c10/util/MaybeOwned.h>
7 #include <ATen/core/Tensor.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/native/sparse/SparseStubs.h>
10 #include <ATen/Parallel.h>
11 #include <ATen/SparseCsrTensorUtils.h>
12 #include <ATen/SparseTensorImpl.h>
13 #include <ATen/ExpandUtils.h>
14 #include <ATen/ScalarOps.h>
15 #include <ATen/InitialTensorOptions.h>
16 #include <ATen/WrapDimUtilsMulti.h>
17 #include <ATen/native/BinaryOps.h>
18 #include <ATen/native/Copy.h>
19 #include <ATen/native/CPUBlas.h>
20 #include <ATen/native/SparseTensorUtils.h>
21
22 #ifndef AT_PER_OPERATOR_HEADERS
23 #include <ATen/Functions.h>
24 #include <ATen/NativeFunctions.h>
25 #else
26 #include <ATen/ops/_sparse_addmm.h>
27 #include <ATen/ops/_sparse_addmm_native.h>
28 #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
29 #include <ATen/ops/_sparse_mm_native.h>
30 #include <ATen/ops/_sparse_sum.h>
31 #include <ATen/ops/_sparse_sum_backward_native.h>
32 #include <ATen/ops/_sparse_sum_native.h>
33 #include <ATen/ops/_sparse_sparse_matmul.h>
34 #include <ATen/ops/_sparse_mm_reduce_impl.h>
35 #include <ATen/ops/_sparse_mm_reduce_impl_native.h>
36 #include <ATen/ops/add.h>
37 #include <ATen/ops/add_native.h>
38 #include <ATen/ops/addmm.h>
39 #include <ATen/ops/addmm_native.h>
40 #include <ATen/ops/arange.h>
41 #include <ATen/ops/any.h>
42 #include <ATen/ops/any_native.h>
43 #include <ATen/ops/bmm_native.h>
44 #include <ATen/ops/cat.h>
45 #include <ATen/ops/conj_physical.h>
46 #include <ATen/ops/conj_physical_native.h>
47 #include <ATen/ops/copy_sparse_to_sparse.h>
48 #include <ATen/ops/div.h>
49 #include <ATen/ops/div_native.h>
50 #include <ATen/ops/empty.h>
51 #include <ATen/ops/empty_like.h>
52 #include <ATen/ops/floor_divide.h>
53 #include <ATen/ops/floor_divide_native.h>
54 #include <ATen/ops/hspmm_native.h>
55 #include <ATen/ops/mm_native.h>
56 #include <ATen/ops/mul.h>
57 #include <ATen/ops/mul_native.h>
58 #include <ATen/ops/mv_native.h>
59 #include <ATen/ops/native_norm_native.h>
60 #include <ATen/ops/neg_native.h>
61 #include <ATen/ops/pow.h>
62 #include <ATen/ops/pow_native.h>
63 #include <ATen/ops/result_type.h>
64 #include <ATen/ops/scalar_tensor.h>
65 #include <ATen/ops/smm_native.h>
66 #include <ATen/ops/sspaddmm.h>
67 #include <ATen/ops/sspaddmm_native.h>
68 #include <ATen/ops/sub_native.h>
69 #include <ATen/ops/zero_native.h>
70 #include <ATen/ops/zeros.h>
71 #include <ATen/ops/zeros_like.h>
72 #include <ATen/ops/zeros_native.h>
73 #include <ATen/ops/index.h>
74 #endif
75
76 #include <algorithm>
77
78 namespace at::native {
79
80 using namespace at::sparse;
81 // --------------------------------------------------------------------
82 // zero_(SparseTensor)
83 // --------------------------------------------------------------------
84
85 // hummu hummu
86 SparseTensor& zero_sparse_(SparseTensor& self) {
87 AT_ASSERT(self.is_sparse());
88 self.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
89 return self._coalesced_(true);
90 }
91
92 // NB: Don't need zeros, zeros_like, already implemented in TensorFactories
93
94 // --------------------------------------------------------------------
95 // mul(SparseTensor, Scalar)
96 // --------------------------------------------------------------------
97
98 SparseTensor& mul_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) {
99 AT_ASSERT(r.is_sparse());
100 AT_ASSERT(t.is_sparse());
101 AT_ASSERT(value.dim() == 0);
102
103 // Resolve a possibly sparse COO value to a strided tensor.
104 Tensor value_;
105 if (value.is_sparse()) {
106 if (value._nnz() == 0) {
107 r.resize_as_(t);
108 return r.zero_();
109 }
110 value_ = value.values();
111 } else {
112 value_ = value;
113 }
114 // With broadcasting in action, value_ may be a 1-D tensor as long
115 // as its shape is (1,).
116 AT_ASSERT(value_.numel() == 1);
117
118 if (is_same_tensor(r, t)) {
119 r._values().mul_(value_);
120 } else {
121 r.resize_as_(t);
122 auto indices = r._indices();
123 indices.resize_as_(t._indices());
124 indices.copy_(t._indices());
125 Tensor r_values = r._values(); // Sigh... needed because mul_out takes Tensor&
126 at::mul_out(r_values, t._values(), value_);
127 get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
128 r._coalesced_(t.is_coalesced());
129 }
130 return r;
131 }
132
133 SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, const Scalar& value) {
134 return mul_out_sparse_zerodim(r, t, wrapped_scalar_tensor(value));
135 }
136
137 // --------------------------------------------------------------------
138 // neg(SparseTensor)
139 // --------------------------------------------------------------------
140
141 SparseTensor& neg_out_sparse(const SparseTensor& t, SparseTensor& r) {
142 TORCH_CHECK(r.is_sparse(), "Tensor should be sparse");
143 TORCH_CHECK(t.is_sparse(), "Tensor should be sparse");
144
145 // copy_sparse_ does not perform the copy if it is the same tensor
146 copy_sparse_to_sparse_(r, t);
147 r._values().neg_();
148 return r;
149 }
150
151 SparseTensor neg_sparse(const SparseTensor& t) {
152 SparseTensor r = at::empty_like(t);
153 neg_out_sparse(t, r);
154 return r;
155 }
156
157 SparseTensor& neg_sparse_(SparseTensor& t) {
158 return neg_out_sparse(t, t);
159 }
160
161 // --------------------------------------------------------------------
162 // pow(SparseTensor, Scalar)
163 // --------------------------------------------------------------------
164
165 // TODO: add in-place variant
166
167 SparseTensor& pow_out_sparse_scalar(const SparseTensor& t_, const Scalar& value, SparseTensor& r) {
168 AT_ASSERT(r.is_sparse());
169 AT_ASSERT(t_.is_sparse());
170 TORCH_CHECK(value.toDouble() != 0, "pow: cannot raise to zeroth power on sparse tensor; it would make the result tensor dense");
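  // (Raising to the zeroth power would map every unspecified element 0 to 0**0 == 1,
  // so every entry of the result would become non-zero, i.e. effectively dense.)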
171
172 // This coalesce is why we can't easily provide an inplace variant
173 SparseTensor t = t_.coalesce();
174
175 r.resize_as_(t);
176 auto indices = r._indices();
177 indices.resize_as_(t._indices());
178 indices.copy_(t._indices());
179 Tensor r_values = r._values(); // Sigh... needed because pow_out takes Tensor&
180 at::pow_out(r_values, t._values(), value);
181 get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
182 return r._coalesced_(t.is_coalesced());
183 }
184
185 SparseTensor pow_sparse_scalar(const SparseTensor& t, const Scalar& value) {
186 SparseTensor r = at::empty({0}, t.options());
187 pow_out_sparse_scalar(t, value, r);
188 return r;
189 }
190
191 // --------------------------------------------------------------------
192 // coalesce(SparseTensor)
193 // --------------------------------------------------------------------
194
195 static SparseTensor& coalesce_(SparseTensor& tensor) {
196 if (tensor.is_coalesced()) {
197 return tensor;
198 }
199
200 SparseTensor coalesced = tensor.coalesce();
201 tensor._values().resize_as_(coalesced._values());
202 tensor._indices().resize_as_(coalesced._indices());
203 tensor._values().copy_(coalesced._values());
204 tensor._indices().copy_(coalesced._indices());
205 tensor._coalesced_(true);
206 return tensor;
207 }
208
209 // Note [Sparse Floor Division]
210 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
211 // Uncoalesced sparse tensors cannot be floor divided correctly. Integer
212 // division is considered a special-case of floor division for purposes of
213 // this note.
214 // For example, an integer tensor with values=[3, 3] divided by 2 would produce
215 // values=[1, 1], which sum to 2 instead of 3 (=6/2).
216 // A float tensor with values=[3., 3.] floor divided by 2 would also produce
217 // values=[1., 1.] (after truncation), which sum to 2.f instead of 3.f.
218 // To perform floor division the sparse tensor must be coalesced first.
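//
// A minimal illustrative sketch of the hazard (not part of this file's API; it
// assumes the usual ATen factory functions and is shown only as a comment):
//
//   auto indices = at::tensor({0, 0}, at::kLong).reshape({1, 2}); // duplicate index 0
//   auto values  = at::tensor({3, 3});
//   auto t = at::sparse_coo_tensor(indices, values, {1});         // logically stores 6 at index 0
//   // Floor-dividing the stored values by 2 gives [1, 1] -> 2 after summing duplicates,
//   // whereas coalescing first gives floor(6 / 2) == 3.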
219 // --------------------------------------------------------------------
220 // div(SparseTensor, Scalar)
221 // --------------------------------------------------------------------
222
223 SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, std::optional<c10::string_view> rounding_mode, SparseTensor& r) {
224 TORCH_CHECK(value.dim() == 0, "Sparse division requires a scalar or ",
225 "zero-dim dense tensor divisor (got shape ", value.sizes(), " for divisor)");
226 TORCH_CHECK(!value.is_sparse(), "Sparse division requires a scalar or ",
227 "zero-dim dense tensor divisor (got a sparse divisor)");
228
229 AT_ASSERT(r.is_sparse());
230 AT_ASSERT(t.is_sparse());
231
232 // See note "Sparse Floor Division"
233 const bool should_coalesce = rounding_mode.has_value() && !t.is_coalesced();
234 if (is_same_tensor(r, t)) {
235 if (should_coalesce) {
236 coalesce_(r);
237 }
238 r._values().div_(value, rounding_mode);
239 } else {
240 Tensor t_tmp = t;
241 if (should_coalesce) {
242 t_tmp = t.coalesce();
243 }
244 r.resize_as_(t_tmp);
245 auto indices = r._indices();
246 indices.resize_as_(t_tmp._indices());
247 indices.copy_(t_tmp._indices());
248 Tensor r_values = r._values(); // Sigh... needed because div_out takes Tensor&
249 at::div_out(r_values, t_tmp._values(), value, rounding_mode);
250 get_sparse_impl(r)->set_nnz_and_narrow(t_tmp._nnz());
251 r._coalesced_(t_tmp.is_coalesced());
252 }
253 return r;
254 }
255
256 SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, SparseTensor& r) {
257 return div_out_sparse_zerodim(t, value, /*rounding_mode=*/std::nullopt, r);
258 }
259
260 Tensor div_sparse(const Tensor& self, const Tensor& value) {
261 auto commonDtype = at::result_type(self, value);
262 if (c10::isIntegralType(commonDtype, /*includeBool=*/true)) {
263 commonDtype = typeMetaToScalarType(at::get_default_dtype());
264 }
265 Tensor result = at::empty({0}, self.options().dtype(commonDtype));
266 return div_out_sparse_zerodim(self, value, result);
267 }
268
269 Tensor& div_sparse_(Tensor& self, const Tensor& value) {
270 return div_out_sparse_zerodim(self, value, self);
271 }
272
273 Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
274 auto commonDtype = at::result_type(self, value);
275 if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) {
276 commonDtype = typeMetaToScalarType(at::get_default_dtype());
277 }
278 Tensor result = at::empty({0}, self.options().dtype(commonDtype));
279 return div_out_sparse_zerodim(self, value, std::move(rounding_mode), result);
280 }
281
282 Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
283 return div_out_sparse_zerodim(self, value, std::move(rounding_mode), self);
284 }
285
286 // --------------------------------------------------------------------
287 // floor_divide(SparseTensor, Scalar)
288 // --------------------------------------------------------------------
289
290 SparseTensor& floor_divide_out_sparse_zerodim(const SparseTensor& dividend,
291 const Tensor& divisor,
292 SparseTensor& result) {
293 TORCH_CHECK(divisor.dim() == 0, "Sparse floor division requires a scalar or ",
294 "zero-dim dense tensor divisor (got shape ", divisor.sizes(), " for divisor)");
295 TORCH_CHECK(!divisor.is_sparse(), "Sparse floor division requires a scalar or ",
296 "zero-dim dense tensor divisor (got a sparse divisor)");
297
298 AT_ASSERT(result.is_sparse());
299 AT_ASSERT(dividend.is_sparse());
300
301 // Case 1: result and dividend are the same tensor
302 // Performs floor division in-place
303 if (is_same_tensor(result, dividend)) {
304
305 // See note "Sparse Floor Division"
306 if (!result.is_coalesced()) {
307 coalesce_(result);
308 }
309
310 result._values().floor_divide_(divisor);
311 return result;
312 }
313
314 // Case 2: result and dividend are different tensors
315 Tensor dividend_tmp = dividend;
316
317 // Ensures dividend_tmp is coalesced (see note above)
318 if (!dividend.is_coalesced()) {
319 dividend_tmp = dividend.coalesce();
320 }
321
322 // Resizes and indexes result like dividend_tmp
323 result.resize_as_(dividend_tmp);
324 result._indices().resize_as_(dividend_tmp._indices());
325 result._indices().copy_(dividend_tmp._indices());
326
327 // Computes result
328 Tensor result_values = result._values();
329 at::floor_divide_out(result_values, dividend_tmp._values(), divisor);
330 get_sparse_impl(result)->set_nnz_and_narrow(dividend_tmp._nnz());
331 result._coalesced_(dividend_tmp.is_coalesced());
332 return result;
333 }
334
335 Tensor floor_divide_sparse(const Tensor& self, const Tensor& value) {
336 auto commonDtype = at::result_type(self, value);
337 Tensor result = at::empty({0}, self.options().dtype(commonDtype));
338 return floor_divide_out_sparse_zerodim(self, value, result);
339 }
340
341 Tensor& floor_divide_sparse_(Tensor& self, const Tensor& value) {
342 return floor_divide_out_sparse_zerodim(self, value, self);
343 }
344
345 // --------------------------------------------------------------------
346 // norm(SparseTensor, Scalar)
347 // --------------------------------------------------------------------
348
349 // Only supports floating point, FYI
350 Tensor norm_sparse(const SparseTensor& self, const Scalar& p) {
351 AT_ASSERT(self.is_sparse());
352 return norm_sparse(self, p, IntArrayRef{}, false, std::nullopt);
353 }
354
355 Tensor norm_sparse(const SparseTensor& self, const std::optional<Scalar>& p, IntArrayRef dim, bool keepdim, std::optional<ScalarType> dtype) {
356 AT_ASSERT(self.is_sparse());
357 if (!dim.empty()) {
358 // Only full reductions are supported, so check if that is the case
359 int64_t ndim = self.dim();
360 bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.size();
361 if (passed_full_reduction_check) {
362 auto dim_ = dim.vec();
363 maybe_wrap_dims(dim_, ndim);
364 std::vector<bool> dims_check(ndim, false);
365 // Need to check for duplicates, and fail if any are found
366 for (auto dim_ind : dim_) {
367 if (dims_check[dim_ind]) {
368 passed_full_reduction_check = false;
369 break;
370 }
371 dims_check[dim_ind] = true;
372 }
373 }
374 TORCH_CHECK(passed_full_reduction_check,
375 "norm_sparse currently only supports full reductions, so 'dim' must either be empty or contain all dimensions of the input");
376 }
377 TORCH_CHECK(keepdim == false, "norm_sparse currently does not support keepdim=True");
378 TORCH_CHECK(!dtype.has_value(), "norm_sparse currently does not support 'dtype' argument");
379 constexpr auto TWO = 2.0;
380 auto p_ = p.value_or(TWO);
381 return self.coalesce()._values().norm(p_);
382 }
383
384 // --------------------------------------------------------------------
385 // mv(SparseTensor, Tensor)
386 // --------------------------------------------------------------------
387
388 Tensor mv_sparse(const SparseTensor& self, const Tensor& vec)
389 {
390 TORCH_CHECK(self.ndimension() == 2 &&
391 vec.ndimension() == 1,
392               "mv: expected a 2-dim sparse tensor and a 1-dim vector, but got ",
393               "SparseTensor Dim: ", self.ndimension(), ", Tensor Dim: ", vec.ndimension());
394
395 TORCH_CHECK(vec.size(-1) == self.size(-1),
396 "mv: expected self.size(-1) == vec.size(-1)");
397
398 auto result = self.matmul(vec.unsqueeze(-1));
399
400 return result.squeeze(-1);
401 }
402
403 // --------------------------------------------------------------------
404 // add(SparseTensor, SparseTensor, Scalar) [broadcasts]
405 // --------------------------------------------------------------------
406
407 Tensor add_sparse(const Tensor& self, const Tensor& other, const Scalar& alpha) {
408 // TODO: Why?! Can't we just flip the order here...
409 TORCH_CHECK(!(self.is_sparse() && !other.is_sparse()),
410 "add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
411 auto commonDtype = at::result_type(self, other);
412 alpha_check(commonDtype, alpha);
413 Tensor result = at::empty({0}, self.options().dtype(commonDtype));
414 return at::add_out(result, self, other, alpha); // redispatch!
415 }
416
417 Tensor& add_sparse_(Tensor& self, const Tensor& other, const Scalar& alpha) {
418 return at::add_out(self, self, other, alpha); // redispatch!
419 }
420
421 // There's actually nothing sparse specific about these implementations
422
423 Tensor sub_sparse(const Tensor& self, const Tensor& other, const Scalar& alpha) {
424 sub_check(self, other);
425 return native::add_sparse(self, other, -alpha);
426 }
427
428 Tensor& sub_sparse_(Tensor& self, const Tensor& other, const Scalar& alpha) {
429 sub_check(self, other);
430 return native::add_sparse_(self, other, -alpha);
431 }
432
433 Tensor& sub_out_sparse(const Tensor& self, const Tensor& other, const Scalar& alpha, Tensor& r) {
434 sub_check(self, other);
435 return at::add_out(r, self, other, -alpha); // redispatch!
436 }
437
438
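// add_out_sparse_contiguous below merges the two index lists with a two-pointer
// sweep (exact deduplication only when both inputs are coalesced): an index present
// in only one operand is copied through, and matching indices land in the same
// output slot so their values are summed, with `value` scaling the src side.
// E.g. t indices [[0, 2]] and src indices [[1, 2]] yield result indices [[0, 1, 2]],
// with the two values at index 2 added together.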
439 static SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, const Scalar& value, ScalarType commonDtype) {
440 // saving those because they can be overwritten when doing in-place operations
441 int64_t t_nnz = t._nnz(), s_nnz = src._nnz(), max_nnz = t_nnz + s_nnz;
442 bool coalesced = t.is_coalesced() && src.is_coalesced();
443 int64_t sparse_dim = src.sparse_dim();
444
445 Tensor r_indices = at::empty({src.sparse_dim(), max_nnz}, t._indices().options());
446
447 Tensor t_values = t._values().to(commonDtype);
448 Tensor s_values = src._values().to(commonDtype);
449
450 Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
451
452 int64_t blockSize = r_values.stride(0);
453 int64_t r_i = 0, t_i = 0, s_i = 0;
454 auto t_indices = t._indices();
455 auto src_indices = src._indices();
456
457 // NB: relies on nnz tests above
458 auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
459 auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
460 auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
461
462 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
463 commonDtype, "cadd_sparse", [&] {
464 scalar_t* t_values_ptr = t_values.data_ptr<scalar_t>();
465 scalar_t* s_values_ptr = s_values.data_ptr<scalar_t>();
466 scalar_t* r_values_ptr = r_values.data_ptr<scalar_t>();
467 scalar_t cast_value = value.to<scalar_t>();
468 while (t_i < t_nnz || s_i < s_nnz) {
469 int64_t cmp;
470 if (t_i >= t_nnz) {
471 cmp = -1;
472 } else if (s_i >= s_nnz) {
473 cmp = 1;
474 } else {
475 cmp = 0;
476 for (auto d: c10::irange(sparse_dim)) {
477 if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
478 cmp = 1;
479 break;
480 }
481 if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
482 cmp = -1;
483 break;
484 }
485 }
486 }
487 if (cmp >= 0) {
488 for (auto d: c10::irange(sparse_dim)) {
489 r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
490 }
491 if (t_values.numel() > 0) { // We add all elements from t_values to r_values only if t_values is not an empty tensor
492 at::native::cpublas::axpy<scalar_t>(blockSize, 1,
493 t_values_ptr + t_i * blockSize, 1,
494 r_values_ptr + r_i * blockSize, 1);
495 }
496 t_i++;
497 }
498 if (cmp <= 0) {
499 for (auto d: c10::irange(sparse_dim)) {
500 r_indices_accessor[d][r_i] = src_indices_accessor[d][s_i];
501 }
502 if (s_values.numel() > 0) { // We add all elements from s_values to r_values only if s_values is not an empty tensor
503 at::native::cpublas::axpy<scalar_t>(blockSize, cast_value,
504 s_values_ptr + s_i * blockSize, 1,
505 r_values_ptr + r_i * blockSize, 1);
506 }
507 s_i++;
508 }
509 r_i++;
510 }
511 }
512 );
513
514 if (r.scalar_type() != commonDtype) {
515 r_values = r_values.to(r.scalar_type());
516 }
517 get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
518 get_sparse_impl(r)->set_nnz_and_narrow(r_i);
519
520 // TODO: I think it may be possible to track inside the loop and
521 // detect when we are uncoalesced (e.g., by observing that an
522 // index goes backwards) which may be more precise than using the
523 // coalesced flag here. But this is easy.
524 return r._coalesced_(coalesced);
525 }
526
527 static SparseTensor& add_out_sparse_non_contiguous(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, const Scalar& value, ScalarType commonDtype) {
528 Tensor t_values = t._values().to(commonDtype);
529 Tensor s_values = src._values().to(commonDtype);
530
531 // If `t` or `src` contains non-contiguous `values`, `at::native::cpublas::axpy` doesn't work
532 // and we concat the indices and values tensors instead.
533 AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
534 commonDtype, "add_out_sparse_cpu", [&] {
535 if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
536 s_values = s_values.mul(value);
537 }
538 });
539
540 Tensor r_indices = at::cat({t._indices(), src._indices()}, 1);
541 Tensor r_values = at::cat({t_values, s_values}, 0).to(r.scalar_type());
542 alias_into_sparse(r, r_indices, r_values);
543
544 // Prevent unbounded growth of nnz
545   // TODO: Improve the heuristic for when to coalesce, or remove the need to coalesce
546 if (r._nnz() > r.numel()) {
547 auto c = r.coalesce();
548 alias_into_sparse(r, c._indices(), c._values());
549 }
550
551 return r;
552 }
553
554 Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, const Scalar& value);
555
556 SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src, const Scalar& value, SparseTensor& r) {
557 if (!t.is_sparse()) {
558 return add_out_dense_sparse_cpu(r, t, src, value);
559 }
560 // TODO: This test seems a bit goofy
561 TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
562 AT_ASSERT(!t.is_cuda()); // the dispatch argument
563 TORCH_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor");
564 TORCH_CHECK(!src.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor");
565
566 TORCH_CHECK(t.sizes().equals(src.sizes()), "add: expected sizes of 'self' and 'other' to match, but ", t.sizes(), " != ", src.sizes());
567
568 auto commonDtype = promoteTypes(t.scalar_type(), src.scalar_type());
569
570 TORCH_CHECK(canCast(commonDtype, r.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r.scalar_type(), " in add operation");
571
572 if (src._nnz() == 0) {
573 return copy_sparse_to_sparse_(r, t);
574 }
575 if (t._nnz() == 0) {
576 return mul_out_sparse_scalar(r, src, value);
577 }
578
579 TORCH_CHECK(is_same_density(t, src), "add: expected 'self' and 'other' to have same density, but 'self' has ", t.sparse_dim(), " sparse dimensions while 'other' has ", src.sparse_dim(), " sparse dimensions");
580
581 r.resize_as_(src);
582 if (r.is_meta()) {
583 return r;
584 } else if (src._values().is_contiguous() && t._values().is_contiguous()) {
585 return add_out_sparse_contiguous(r, t, src, value, commonDtype);
586 } else {
587 return add_out_sparse_non_contiguous(r, t, src, value, commonDtype);
588 }
589 }
590
591 // --------------------------------------------------------------------
592 // add(Tensor, SparseTensor, Scalar)
593 // formerly known as spcadd
594 // --------------------------------------------------------------------
595 template <typename scalar_t>
596 void add_dense_sparse_worker_non_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
597 auto indices_accessor = indices.accessor<int64_t, 2>();
598 auto values_accessor = values.accessor<scalar_t, 1>();
599
600 scalar_t* r_ptr = r.data_ptr<scalar_t>();
601 scalar_t cast_value = value.to<scalar_t>();
602 const int64_t sparse_dim = sparse.sparse_dim();
603 std::vector<int64_t> result_stride(sparse_dim);
604 for (const auto d: c10::irange(sparse_dim)) {
605 result_stride[d] = r.stride(d);
606 }
607 at::parallel_for(0, sparse._nnz(), 0, [&](int64_t start, int64_t end) {
608 for (const auto k: c10::irange(start, end)) {
609 int64_t index = r.storage_offset();
610 for (auto d: c10::irange(sparse_dim)) {
611 index += result_stride[d] * indices_accessor[d][k];
612 }
613 r_ptr[index] += cast_value * values_accessor[k];
614 }
615 });
616 }
617
618 template <typename scalar_t>
619 inline void add_dense_sparse_worker_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
620
621   // Get the number of elements in the dense dimensions of the hybrid sparse tensor
622 int64_t values_dense_size = values.stride(0);
623 TORCH_CHECK(values.is_contiguous());
624 scalar_t* v_ptr = values.data_ptr<scalar_t>();
625
626 scalar_t* r_ptr = r.data_ptr<scalar_t>();
627 TORCH_CHECK(r_ptr != nullptr);
628
629 auto indices_accessor = indices.accessor<int64_t, 2>();
630 scalar_t cast_value = value.to<scalar_t>();
631 auto sparse_dim = sparse.sparse_dim();
632 std::vector<int64_t> result_stride(sparse_dim);
633 for (auto d : c10::irange(sparse_dim)) {
634 result_stride[d] = r.stride(d);
635 }
636
637 at::parallel_for(0, sparse._nnz(), 0, [&](int64_t start, int64_t end) {
638 for (auto k: c10::irange(start, end)) {
639 auto r_index = r_ptr;
640 for (auto d: c10::irange(sparse_dim)) {
641 r_index += result_stride[d] * indices_accessor[d][k];
642 }
643 auto v_index = v_ptr + k * values_dense_size;
644 at::native::cpublas::axpy<scalar_t>(values_dense_size, cast_value, v_index, 1, r_index, 1);
645 }
646 });
647 }
648
649 template <typename scalar_t>
650 inline void add_dense_sparse_worker_non_coalesced_cpu(Tensor& r, const Scalar& value,
651 const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
652
653   // Get the number of elements in the dense dimensions of the hybrid sparse tensor
654 auto values_dense_size = values.stride(0);
655 TORCH_CHECK(values.is_contiguous());
656 scalar_t* v_ptr = values.data_ptr<scalar_t>();
657 TORCH_CHECK(v_ptr != nullptr);
658
659 scalar_t* r_ptr = r.data_ptr<scalar_t>();
660 TORCH_CHECK(r_ptr != nullptr);
661
662 scalar_t cast_value = value.to<scalar_t>();
663 auto sparse_dim = sparse.sparse_dim();
664
665 auto indices_accessor = indices.accessor<int64_t, 2>();
666 int64_t result_length = r.size(0);
667 std::vector<int64_t> result_stride(sparse_dim);
668 for (auto d : c10::irange(sparse_dim)) {
669 result_stride[d] = r.stride(d);
670 }
671
672 auto sparse_nnz = sparse._nnz();
673 int max_threads = at::get_num_threads();
674 max_threads = (result_length < max_threads) ? result_length : max_threads;
675 int64_t avg_chunk_down = result_length / max_threads;
676 std::vector<int64_t> chuck_size(max_threads);
677 for (const auto i : c10::irange(max_threads)) {
678 chuck_size[i] = avg_chunk_down;
679 }
680   // Balance the chunks across threads by spreading the remainder (e.g. sizes 2,1,1)
681 for (auto i = 0 ; i < result_length % max_threads ; i++) {
682 chuck_size[i] += 1;
683 }
684 std::vector<int64_t> chuck_sum_size(max_threads + 1);
685 chuck_sum_size[0] = 0;
686 for (const auto i : c10::irange(1, max_threads)) {
687 chuck_sum_size[i] = chuck_sum_size[i - 1] + chuck_size[i - 1];
688 }
689 chuck_sum_size[max_threads] = result_length;
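  // e.g. result_length = 7, max_threads = 3 -> chuck_size = {3, 2, 2} and
  // chuck_sum_size = {0, 3, 5, 7}; thread k owns rows [chuck_sum_size[k], chuck_sum_size[k + 1]).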
690 at::parallel_for(0, max_threads, 0, [&](int64_t start, int64_t end) {
691 for (auto k: c10::irange(start, end)) {
692 int64_t chunk_begin = chuck_sum_size[k];
693 int64_t chunk_end = chuck_sum_size[k + 1];
694 for (const auto n: c10::irange(sparse_nnz)) {
695 int64_t chunk_offset = indices_accessor[0][n];
696 if (chunk_offset >= chunk_begin && chunk_offset < chunk_end) {
697 int64_t r_offset = result_stride[0] * chunk_offset;
698 for (const auto d : c10::irange(1, sparse_dim)) {
699 r_offset += result_stride[d] * indices_accessor[d][n];
700 }
701 scalar_t* v_index = v_ptr + n * values_dense_size;
702 auto r_index = r_ptr + r_offset;
703 at::native::cpublas::axpy<scalar_t>(values_dense_size, cast_value, v_index, 1, r_index, 1);
704 }
705 }
706 }
707 });
708 }
709
710 Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, const Scalar& value) {
711 TORCH_CHECK(!r.is_sparse());
712 TORCH_CHECK(!dense.is_sparse());
713 TORCH_CHECK(sparse_.is_sparse());
714
715 TORCH_CHECK(!dense.is_cuda()); // dispatch argument
716 TORCH_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor");
717 TORCH_CHECK(!sparse_.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor");
718
719 TORCH_CHECK(dense.sizes().equals(sparse_.sizes()), "add: expected 'self' and 'other' to have same size, but self has size ",
720 dense.sizes(), " while other has size ", sparse_.sizes(), " (FYI: dense-sparse addition does not currently support broadcasting)");
721
722 auto commonDtype = promoteTypes(dense.scalar_type(), sparse_.scalar_type());
723 TORCH_CHECK(canCast(commonDtype, r.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r.scalar_type(), " in add operation");
724
725 r.resize_as_(dense);
726
727 auto sparse_nnz = sparse_._nnz();
728 if (sparse_nnz == 0) {
729 if (!is_same_tensor(r, dense)) r.copy_(dense);
730 return r;
731 }
732
733 int64_t dense_dim = dense.dim();
734 int64_t sparse_dim = sparse_.sparse_dim();
735 Tensor resultBuffer = r;
736 if (r.scalar_type() != commonDtype) {
737 resultBuffer = dense.to(commonDtype);
738 } else if (!is_same_tensor(r, dense)) {
739 resultBuffer.copy_(dense);
740 }
741
742 Tensor values = sparse_._values();
743 bool sparse_is_coalesced = (sparse_.is_coalesced() || sparse_nnz == 1);
744 bool result_is_contiguous = ((r.storage().data() != nullptr) && resultBuffer.is_contiguous());
745 bool value_is_contiguous = values.is_contiguous();
746 bool is_contiguous = (result_is_contiguous && value_is_contiguous);
747
748 SparseTensor sparse = sparse_;
749 Tensor indices = sparse_._indices();
750 Tensor valuesBuffer = values.to(commonDtype);
751 if (is_contiguous && sparse_is_coalesced) {
752     // TODO: we can optimize the non-hybrid case by not using buffers
753 if (sparse_dim == dense_dim) {
754 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
755 at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
756 commonDtype, "add_dense_sparse_non_hybrid", [&] {
757 add_dense_sparse_worker_non_hybrid_cpu<scalar_t>(resultBuffer, value, sparse_, indices, valuesBuffer);
758 });
759 } else {
760 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
761 at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
762 commonDtype, "add_dense_sparse_hybrid", [&] {
763 add_dense_sparse_worker_hybrid_cpu<scalar_t>(resultBuffer, value, sparse_, indices, valuesBuffer);
764 });
765 }
766 } else if (is_contiguous && (sparse_dim > 0)) {
767 // Handle sparse is not coalesced
768 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
769 at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
770 commonDtype, "add_dense_sparse_worker_non_coalesced", [&] {
771 add_dense_sparse_worker_non_coalesced_cpu<scalar_t>(resultBuffer, value, sparse_, indices, valuesBuffer);
772 });
773 } else {
774 // Slow path for non-contiguous values and output
775     // TODO: coalesce() performance may be further improved
776 sparse = sparse_.coalesce();
777 indices = sparse._indices();
778 values = sparse._values();
779 valuesBuffer = values.to(commonDtype);
780 auto indices_accessor = indices.accessor<int64_t, 2>();
781 auto sparse_nnz = sparse._nnz();
782 at::parallel_for(0, sparse_nnz, 100, [&](int64_t start, int64_t end) {
783 for (auto k: c10::irange(start, end)) {
784 Tensor dstBuffer = resultBuffer;
785 for (auto d: c10::irange(sparse_dim)) {
786 dstBuffer = dstBuffer.select(0, indices_accessor[d][k]);
787 }
788 Tensor srcBuffer = valuesBuffer.select(0, k);
789 dstBuffer.add_(srcBuffer, value);
790 }
791 });
792 }
793 if (r.scalar_type() != commonDtype) {
794 r.copy_(resultBuffer);
795 }
796 return r;
797 }
798
799 // --------------------------------------------------------------------
800 // mul(SparseTensor, SparseTensor) [broadcasts]
801 // --------------------------------------------------------------------
802
803 Tensor mul_sparse(const Tensor& self, const Tensor& other) {
804 auto commonDtype = at::result_type(self, other);
805 // Arbitrary (dense, sparse) and (sparse, dense) multiplication is not
806   // currently supported, but (0dim-dense, sparse) and (sparse, 0dim-dense) are.
807 // Make sure we use the sparse exemplar for result.
808 auto result_options = self.is_sparse() ?
809 self.options().dtype(commonDtype) : other.options().dtype(commonDtype);
810 Tensor result = at::empty({0}, result_options);
811 return at::mul_out(result, self, other); // redispatch!
812 }
813
814 Tensor& mul_sparse_(Tensor& self, const Tensor& other) {
815 if (self.is_sparse()) {
816 return at::mul_out(self, self, other); // redispatch!
817 }
818 else {
819 const auto res = at::mul(self, other);
820 self.zero_();
821 self.add_(res);
822 return self;
823 }
824 }
825
826 // A generic function to implement pointwise-like operations
827 // with index intersection between dense and sparse COO tensors.
828 // NOTE: op is always called as op(dense_values, sparse_values),
829 // so it is up to the user to supply the right implementations for non-commutative
830 // operations.
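// For example, a subtraction-like op would be passed as
//   [](const Tensor& dense_vals, const Tensor& sparse_vals) { return dense_vals.sub(sparse_vals); }
// (a hypothetical caller, shown only for illustration).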
831 template <typename binary_func_t>
832 Tensor& intersection_binary_op_sparse_dense_out(
833 const Tensor& d,
834 const SparseTensor& s_,
835 Tensor& res,
836 const char* const op_name,
837 const binary_func_t& op,
838 const bool coalesce = false) {
839 // compute broadcasted shape.
840 const auto res_shape = infer_size(d.sizes(), s_.sizes());
841
842 // Short-circuit if either s_ or d is empty.
843 if (!s_._nnz() || !s_.numel() || !d.numel()) {
844 const int64_t dense_dim = s_.dense_dim();
845 const int64_t sparse_dim = static_cast<int64_t>(res_shape.size()) - dense_dim;
846 const int64_t nnz = 0;
847 const auto indices = at::empty({sparse_dim, nnz}, s_._indices().options());
848 auto res_values_shape = s_._values().sizes().vec();
849 res_values_shape[0] = nnz;
850 const auto values = at::empty(res_values_shape, s_._values().options().dtype(res.scalar_type()));
851 auto* res_impl = get_sparse_impl(res);
852 res_impl->raw_resize_(sparse_dim, dense_dim, /*size=*/res_shape);
853 res_impl->set_indices_and_values_unsafe(indices, values);
854 res_impl->set_nnz_and_narrow(nnz);
855 return res._coalesced_(true);
856 }
857
858 const auto d_dim = d.dim();
859 const auto s_dim = s_.dim();
860
861 // Always coalesce when sparse broadcasts over dense,
862 // because new sparse dimensions are created and
863 // repeated indices have to be eliminated because of that.
864 const auto s = (coalesce || d_dim > s_dim) ? s_.coalesce() : s_;
865
866 const auto sparse_dim = s.sparse_dim();
867 const auto dense_dim = s.dense_dim();
868
869 const auto s_indices = s._indices();
870 const auto s_values = s._values();
871
872 const auto apply_op = [&](const Tensor& d_filtered) -> Tensor& {
873 const auto res_indices = s_indices.clone();
874 // to(res.scalar_type) is only performed when both d and s are 0-dim.
875     // This ensures the right type promotions under the following rules:
876 // op(0-dim, 0-dim).dtype == <common dtype>
877 // op(0-dim, ge-1-dim).dtype == <ge-1-dim>.dtype,
878 // where ge-1-dim is a tensor with dim >= 1.
879 // We do not cast if op is performed in-place.
880     // The cast is required if s is a 0-dim non-coalesced tensor and d is 0-dim.
881 // This is because s.values is at least 1D, so
882 // op(s.values, d).dtype == s.values.dtype, but we want
883 // op(s.values, d).dtype == <common dtype>.
884 const auto values = op(d_filtered, s_values);
885 const auto res_values = is_same_tensor(s_, res) ? values : values.to(res.scalar_type());
886 auto* res_impl = get_sparse_impl(res);
887 res_impl->raw_resize_(sparse_dim, dense_dim, res_shape);
888 res_impl->set_indices_and_values_unsafe(res_indices, res_values);
889 res_impl->set_nnz_and_narrow(s._nnz());
890 return res._coalesced_(s.is_coalesced());
891 };
892
893 // Easiest case: only dense dimensions intersect.
894 // This means only value tensors interact.
895 if (d_dim <= dense_dim) {
896 return apply_op(d);
897 }
898
899 // Now we have intersection between sparse and dense dims.
900 const auto sparse_dim_intersec = std::min(sparse_dim, d_dim - dense_dim);
901 const auto d_start_dim_intersec = std::max<int64_t>(0, d_dim - s_dim);
902 const auto s_start_dim_intersec = std::max<int64_t>(0, s_dim - d_dim);
903
904 // Index d with s_indices to find values which
905 // interact with s_values.
906 const auto d_filtered = [&]() -> Tensor {
907 using at::indexing::Slice;
908 using at::indexing::Ellipsis;
909 using at::indexing::TensorIndex;
910
911 std::vector<TensorIndex> intersec_indices;
912 intersec_indices.reserve(d_dim);
913
914 if (d_start_dim_intersec) {
915 intersec_indices.emplace_back(Ellipsis);
916 }
917 for (const auto i : c10::irange(sparse_dim_intersec)) {
918 const auto s_idx = s_start_dim_intersec + i;
919 intersec_indices.emplace_back(s_indices[s_idx]);
920 }
921 for (auto i = d_start_dim_intersec + sparse_dim_intersec; i < d_dim; ++i) {
922 intersec_indices.emplace_back(Slice());
923 }
924 // we need to expand d in the dimensions it is being indexed into
925 // to avoid out of bound indices
926 const auto d_expanded_shape = std::vector<int64_t>(
927 res_shape.end() - d_dim, res_shape.end());
928 return d.expand(d_expanded_shape).index(intersec_indices);
929 }();
930
931 // When dims match or sparse is "larger", the result nnz is the same,
932 // so only values get modified.
933 if (s_dim >= d_dim) {
934 return apply_op(d_filtered);
935 }
936
937 // Otherwise nnz gets larger, and both indices and values need an update.
938 const auto d_batch_shape = d.sizes().slice(0, d_start_dim_intersec);
939 const auto d_batch_len = static_cast<int64_t>(d_batch_shape.size());
940 int64_t batch_count = 1;
941 int64_t max_batch_dim = 0;
942 std::tie(batch_count, max_batch_dim) = [d_batch_shape]() -> std::tuple<int64_t, int64_t> {
943 int64_t batch_count = 1;
944 int64_t max_batch_dim = 0;
945 for (const auto& b : d_batch_shape) {
946 batch_count *= b;
947 max_batch_dim = std::max(b, max_batch_dim);
948 }
949 return std::make_tuple(batch_count, max_batch_dim);
950 }();
951
952 const auto res_sparse_dim = static_cast<int64_t>(d_batch_shape.size()) + sparse_dim;
953 const auto res_dense_dim = dense_dim;
954 const auto s_nnz = s._nnz();
955 const auto res_nnz = batch_count * s_nnz;
956 auto res_values_shape = s_values.sizes().vec();
957 res_values_shape[0] = res_nnz;
958 const auto res_values = op(d_filtered, s_values).reshape(res_values_shape);
959 const auto res_indices = [&]() -> Tensor {
960 const auto index_buffer = at::arange(max_batch_dim, s_indices.options());
961 auto indices = at::empty({res_sparse_dim, res_nnz}, s_indices.options());
962 // fill in indices corresponding to the "batch" dimensions of d.
963 int64_t n_repeat_interleave = res_nnz;
964 int64_t n_repeat = 1;
965 for (const auto dim : c10::irange(d_batch_len)) {
966 const auto dim_size = d_batch_shape[dim];
967 n_repeat_interleave /= dim_size;
968 // fill in indices corresponding to the "batch" dimension dim.
969 // Equivalent to indices[dim].copy_(repeat_interleave(dim_index, n_repeat_interleave).repeat(n_repeat))
970 const std::initializer_list<int64_t> dim_index_expanded_shape = {n_repeat, dim_size, n_repeat_interleave};
971 const auto dim_index = index_buffer.slice(-1, 0, dim_size);
972 const auto dim_index_expanded = dim_index.unsqueeze(0).unsqueeze_(-1).expand(dim_index_expanded_shape);
973 // NOTE: indices is contiguous, so view is safe
974 indices[dim].view(dim_index_expanded_shape).copy_(dim_index_expanded);
975 n_repeat *= dim_size;
976 }
977 // fill in indices corresponding to s_indices.
978 // Equivalent to indices_sparse.copy(s_indices.repeat({1, n_repeat})
979 n_repeat = res_nnz / s_nnz;
980 auto indices_sparse = indices.narrow(0, d_batch_len, res_sparse_dim - d_batch_len);
981 const std::initializer_list<int64_t> s_indices_expanded_shape = {-1, n_repeat, s_nnz};
982 const auto s_indices_expanded = s_indices.unsqueeze(1).expand(s_indices_expanded_shape);
983 indices_sparse.view(s_indices_expanded_shape).copy_(s_indices_expanded);
984
985 return indices;
986 }();
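  // Worked example of the layout built above: with d_batch_shape = {2, 3} and s_nnz = 2
  // (so res_nnz = 12), the two batch index rows become
  //   [0 0 0 0 0 0 1 1 1 1 1 1] and [0 0 1 1 2 2 0 0 1 1 2 2],
  // while the remaining rows tile s_indices six times along the nnz dimension.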
987 auto* res_impl = get_sparse_impl(res);
988 res_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
989 res_impl->set_indices_and_values_unsafe(res_indices, res_values);
990 res_impl->set_nnz_and_narrow(res_nnz);
991 // By design of index expansion and that s is coalesced,
992 // the result is also coalesced.
993 return res._coalesced_(true);
994 }
995
996 Tensor& _mul_dense_sparse_out(const Tensor& d, const Tensor& s, Tensor& res) {
997 return intersection_binary_op_sparse_dense_out(d, s, res, "mul", [](const Tensor& a, const Tensor& b) -> Tensor {
998 return at::mul(a, b);
999 });
1000 }
1001
1002 Tensor& _mul_sparse_sparse_zero_dim_out(const Tensor& zero_dim, const Tensor& other, Tensor& r) {
1003 const auto is_wrapped_scalar = [](const Tensor& s) -> bool {
1004 return !s.dim() && s.is_coalesced();
1005 };
1006
1007 const auto extract_vals_from_wrapped_scalar = [](const Tensor& s) -> Tensor {
1008 auto vals = s._values().squeeze(0);
1009 // if squeeze does not kill the dim, it means that
1010 // vals is empty with shape [0]. In such a case we
1011 // return a 0-dim empty tensor to avoid broadcasting
1012 // issues in intersection_binary_op_sparse_dense_out
1013 // when the sparse argument is actually 0-dim.
1014 if (vals.dim()) {
1015 return at::empty({}, vals.options());
1016 }
1017 return vals;
1018 };
1019
1020 // The code dispatches to mul(dense, sparse), and the goal
1021 // is to delay calling into coalesce when converting one of
1022 // the sparse arguments to dense if possible.
1023 // This is possible when there is a 0-dim coalesced argument.
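// For example, for mul(sparse_nd, sparse_0d) with a coalesced 0-dim operand we can
// read off its single value directly; otherwise we coalesce the (cheap) 0-dim side
// rather than the potentially large n-dim operand.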
1024
1025 // if is_wrapped_scalar(zero_dim)
1026 if (zero_dim.is_coalesced()) {
1027 const auto scalar_val = extract_vals_from_wrapped_scalar(zero_dim);
1028 return _mul_dense_sparse_out(scalar_val, other, r);
1029 }
1030 // Here zero_dim is not a wrapped scalar, so we test other.
1031 if (is_wrapped_scalar(other)) {
1032 const auto scalar_val = extract_vals_from_wrapped_scalar(other);
1033 return _mul_dense_sparse_out(scalar_val, zero_dim, r);
1034 }
1035 // Neither of inputs is a wrapped scalar, but zero_dim
1036 // is at least 0-dim, so we coalesce it to convert to
1037 // a scalar.
1038 const auto scalar_val = extract_vals_from_wrapped_scalar(zero_dim.coalesce());
1039 return _mul_dense_sparse_out(scalar_val, other, r);
1040 }
1041
1042 DEFINE_DISPATCH(mul_sparse_sparse_out_stub);
1043
1044 Tensor& _mul_sparse_sparse_out(const Tensor& x, const Tensor& y, Tensor& res) {
1045 mul_sparse_sparse_out_stub(res.device().type(), res, x, y);
1046 return res;
1047 }
1048
1049 SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, Tensor& r) {
1050 AT_ASSERT(!t_.is_cuda()); // dispatch argument
1051 TORCH_CHECK(!r.is_cuda(), "mul: expected 'out' to be CPU tensor, but got CUDA tensor");
1052 TORCH_CHECK(!src_.is_cuda(), "mul: expected 'other' to be a CPU tensor, but got a CUDA tensor");
1053 // case mul(sparse, dense)
1054 if (!src_.is_sparse()) {
1055 return _mul_dense_sparse_out(src_, t_, r);
1056 }
1057 // case mul(dense, sparse)
1058 if (!t_.is_sparse()) {
1059 return _mul_dense_sparse_out(t_, src_, r);
1060 }
1061
1062 // case mul(sparse, sparse) with a 0-dim input.
1063 if (!src_.dim()) {
1064 return _mul_sparse_sparse_zero_dim_out(src_, t_, r);
1065 }
1066 if (!t_.dim()) {
1067 return _mul_sparse_sparse_zero_dim_out(t_, src_, r);
1068 }
1069
1070 const auto is_equal_size_inputs = t_.sizes().equals(src_.sizes());
1071
1072 // mul(sparse, sparse) with inputs which broadcast only in dense dims
1073 if (!is_equal_size_inputs) {
1074 _mul_sparse_sparse_out(t_, src_, r);
1075 return r;
1076 }
1077
1078 TORCH_CHECK(is_equal_size_inputs, "mul: expected 'self' and 'other' to have same sizes when both are sparse"
1079 ", but ", t_.sizes(), " != ", src_.sizes());
1080
1081 // Short circuit when there is zero nnz
1082 // Not strictly necessary, but there are tests checking whether
1083 // resize in mul fails if run on tensors coming from .data/.detach.
1084 if (!t_._nnz() || !src_._nnz()) {
1085 r.resize_as_(t_);
1086 return r.zero_();
1087 }
1088
1089 // _mul_sparse_sparse_out is faster for large inputs
1090 // and when either of the inputs is uncoalesced.
1091 if (!t_.is_coalesced() || !src_.is_coalesced()) {
1092 _mul_sparse_sparse_out(t_, src_, r);
1093 return r;
1094 }
1095
1096 // Otherwise _mul_sparse_sparse_out might be slower
1097 // than the brute-force solution below.
1098
1099 SparseTensor t = t_.coalesce();
1100 SparseTensor src = src_.coalesce();
1101
1102 // saving those because they can be overwritten when doing in-place operations
1103 int64_t t_nnz = t._nnz(), s_nnz = src._nnz();
1104 int64_t max_nnz = std::min(t_nnz, s_nnz); // multiply by zero is zero, and can be dropped
1105 int64_t sparse_dim = src.sparse_dim();
1106 Tensor t_indices = t._indices();
1107 Tensor src_indices = src._indices();
1108 Tensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
1109
1110 int64_t r_i = 0, t_i = 0, s_i = 0;
1111
1112 auto commonDtype = promoteTypes(t_.scalar_type(), src_.scalar_type());
1113 TORCH_CHECK(canCast(commonDtype, r.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r.scalar_type(), " in mul operation");
1114
1115 Tensor t_values = t._values().to(commonDtype);
1116 Tensor s_values = src._values().to(commonDtype);
1117
1118 Tensor r_buffer = new_values_with_size_of(t_values, max_nnz).zero_();
1119
1120 // NB: relies on nnz test above
1121 auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
1122 auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
1123 auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
1124
1125 // Check if we can find matching indices, and if so, write an
1126 // entry to the result indices vector. Returns true if matching
1127 // indices were found.
1128 auto index_preamble = [&]() {
1129 for (auto d: c10::irange(sparse_dim)) {
1130 if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
1131 t_i++;
1132 return false;
1133 }
1134 if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
1135 s_i++;
1136 return false;
1137 }
1138 }
1139 for (auto d: c10::irange(sparse_dim)) {
1140 r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
1141 }
1142 return true;
1143 };
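  // For example, with t indices [[0, 1, 3]] and src indices [[1, 2, 3]] the sweep skips
  // the unmatched 0 and 2 and multiplies values only at the shared indices 1 and 3.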
1144
1145 if (t_values.dim() > 1) {
1146 while (t_i < t_nnz && s_i < s_nnz) {
1147 if (!index_preamble()) continue;
1148 r_buffer.select(0, r_i).addcmul_(t_values.select(0, t_i), s_values.select(0, s_i));
1149 r_i++;
1150 t_i++;
1151 s_i++;
1152 }
1153 } else {
1154 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
1155 at::ScalarType::ComplexHalf, at::ScalarType::BFloat16, at::ScalarType::Half,
1156 commonDtype, "mul_out_sparse", [&] {
1157 auto r_accessor = r_buffer.accessor<scalar_t, 1>();
1158 auto t_accessor = t_values.accessor<scalar_t, 1>();
1159 auto s_accessor = s_values.accessor<scalar_t, 1>();
1160
1161 while (t_i < t_nnz && s_i < s_nnz) {
1162 if (!index_preamble()) continue;
1163 r_accessor[r_i] = t_accessor[t_i] * s_accessor[s_i];
1164 r_i++;
1165 t_i++;
1166 s_i++;
1167 }
1168 }
1169 );
1170 }
1171
1172 r.resize_as_(src);
1173 Tensor r_values = r_buffer.to(r.scalar_type());
1174 get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
1175 get_sparse_impl(r)->set_nnz_and_narrow(r_i);
1176 return r._coalesced_(true);
1177 }
1178
1179 // --------------------------------------------------------------------
1180 // addmm(D1, S, D2, beta, alpha) -> D [broadcasts]
1181 //
1182 // D = beta * D1 + alpha * mm(S, D2)
1183 // --------------------------------------------------------------------
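// The CPU kernel below first initializes r from t scaled by beta (zeroing or copying
// in the special cases), then for every stored entry (row, col, val) of the sparse
// matrix accumulates
//   r[row, :] += alpha * val * dense[col, :]
// via cpublas::axpy.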
1184
1185 template <typename scalar_t>
1186 void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& indices, const Tensor& values, const Tensor& dense) {
1187
1188 // r_ = alpha * sparse * dense
1189 scalar_t cast_alpha = alpha.to<scalar_t>();
1190 scalar_t cast_beta = beta.to<scalar_t>();
1191
1192 if (cast_beta == static_cast<scalar_t>(0)) {
1193 r.zero_();
1194 } else if (cast_beta == static_cast<scalar_t>(1)) {
1195 if (!is_same_tensor(r, t)) {
1196 r.copy_(t);
1197 }
1198 } else {
1199 at::mul_out(r, t, scalar_to_tensor(beta));
1200 }
1201
1202 auto indices_accessor = indices.accessor<int64_t, 2>();
1203
1204 auto values_accessor = values.accessor<scalar_t, 1>();
1205 scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
1206 scalar_t* r_ptr = r.data_ptr<scalar_t>();
1207
1208 int64_t dense_stride0 = dense.stride(0);
1209 int64_t dense_stride1 = dense.stride(1);
1210 int64_t r_stride0 = r.stride(0);
1211 int64_t r_stride1 = r.stride(1);
1212 for (auto i: c10::irange(nnz)) {
1213 scalar_t val = values_accessor[i];
1214 int64_t row = indices_accessor[0][i];
1215 int64_t col = indices_accessor[1][i];
1216 if (col >= 0 && col < dim_j && row >= 0 && row < dim_i) {
1217 // AXPY call is no-op over an empty vector
1218 if (dim_k == 0) {
1219 continue;
1220 }
1221 at::native::cpublas::axpy<scalar_t>(dim_k,
1222 cast_alpha * val,
1223 dense_ptr + col * dense_stride0, dense_stride1,
1224 r_ptr + row * r_stride0, r_stride1);
1225 } else {
1226 if (col < 0 || col >= dim_j) {
1227         AT_ERROR("addmm: index out of column bound: ", col, " not in range [0, ", dim_j, ")");
1228 } else {
1229         AT_ERROR("addmm: index out of row bound: ", row, " not in range [0, ", dim_i, ")");
1230 }
1231 }
1232 }
1233 };
1234
1235 static Tensor& s_addmm_out_sparse_dense_cpu(
1236 Tensor& r,
1237 const Tensor& t,
1238 const SparseTensor& sparse_,
1239 const Tensor& dense,
1240 const Scalar& beta,
1241 const Scalar& alpha) {
1242 // TODO: This error message seems awfully opaque
1243 TORCH_CHECK(
1244 t.is_cpu(),
1245 "Expected all tensors to be on the same device. addmm expected 't' to be CPU tensor, but got tensor on ",
1246 t.device());
1247 TORCH_CHECK(
1248 r.is_cpu(),
1249 "Expected all tensors to be on the same device. addmm: expected 'out' to be CPU tensor, but got tensor on ",
1250 r.device());
1251 TORCH_CHECK(
1252 sparse_.is_cpu(),
1253 "Expected all tensors to be on the same device. addmm: expected 'mat1' to be a CPU tensor, but got tensor on ",
1254 sparse_.device());
1255 TORCH_CHECK(
1256 dense.is_cpu(),
1257 "Expected all tensors to be on the same device. addmm: expected 'mat2' to be a CPU tensor, but got tensor on ",
1258 dense.device());
1259
1260 TORCH_CHECK(
1261 r.layout() == kStrided,
1262 "addmm_sparse_dense: expected strided result tensor, got tensor with layout ",
1263 r.layout());
1264 TORCH_CHECK(
1265 t.layout() == kStrided,
1266 "addmm_sparse_dense: expected 't' to have strided layout, got tensor with layout ",
1267 t.layout());
1268 TORCH_CHECK(
1269 sparse_.layout() == kSparse && dense.layout() == kStrided,
1270 "addmm_sparse_dense: expected either 'mat1' to have sparse layout and 'mat2' to have strided layout, got 'mat1' with layout ",
1271 sparse_.layout(),
1272 " and 'mat2' with layout ",
1273 dense.layout());
1274
1275 TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: matrices expected, got ", sparse_.sparse_dim(), "D tensor");
1276 TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: scalar values expected, got ", sparse_.dense_dim(), "D values");
1277 TORCH_CHECK(dense.dim() == 2, "addmm: matrices expected, got ", dense.dim(), "D tensor");
1278
1279 // ixj * jxk = ixk
1280 int64_t dim_i = sparse_.size(0);
1281 int64_t dim_j = sparse_.size(1);
1282 int64_t dim_k = dense.size(1);
1283
1284 TORCH_CHECK(dense.size(0) == dim_j,
1285 "addmm: Argument #3 (dense): Expected dim 0 size ", dim_j, ", got ", dense.size(0));
1286 TORCH_CHECK(t.size(0) == dim_i,
1287 "addmm: Argument #1 (t): Expected dim 0 size ", dim_i, ", got ", t.size(0));
1288 TORCH_CHECK(t.size(1) == dim_k,
1289 "addmm: Argument #1 (t): Expected dim 1 size ", dim_k, ", got ", t.size(1));
1290
1291 r.resize_({dim_i, dim_k});
1292
1293 int64_t nnz = sparse_._nnz();
1294
1295 if (nnz == 0) {
1296 at::mul_out(r, t, at::scalar_tensor(beta, r.options()));
1297 return r;
1298 }
1299
1300 Tensor indices = sparse_._indices();
1301 Tensor values = sparse_._values();
1302
1303 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
1304 values.scalar_type(), "addmm_sparse_dense", [&] {
1305 s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense);
1306 }
1307 );
1308
1309 return r;
1310 }
1311
1312 Tensor& addmm_out_sparse_dense_cpu(
1313 const Tensor& self,
1314 const SparseTensor& mat1,
1315 const Tensor& mat2,
1316 const Scalar& beta,
1317 const Scalar& alpha,
1318 Tensor& result) {
1319 c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
1320 return s_addmm_out_sparse_dense_cpu(result, *b_self, mat1, mat2, beta, alpha);
1321 }
1322
1323 static Tensor s_addmm_sparse_dense_cpu(
1324 const Tensor& t,
1325 const SparseTensor& sparse,
1326 const Tensor& dense,
1327 const Scalar& beta,
1328 const Scalar& alpha
1329 ) {
1330 Tensor r = at::empty({0}, t.options());
1331 s_addmm_out_sparse_dense_cpu(r, t, sparse, dense, beta, alpha);
1332 return r;
1333 }
1334
1335 Tensor addmm_sparse_dense_cpu(
1336 const Tensor& self,
1337 const SparseTensor& mat1,
1338 const Tensor& mat2,
1339 const Scalar& beta,
1340 const Scalar& alpha
1341 ) {
1342 c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
1343 return s_addmm_sparse_dense_cpu(*b_self, mat1, mat2, beta, alpha);
1344 }
1345
1346 Tensor& s_addmm_sparse_dense_cpu_(
1347 Tensor& t,
1348 const SparseTensor& sparse,
1349 const Tensor& dense,
1350 const Scalar& beta,
1351 const Scalar& alpha
1352 ) {
1353 return s_addmm_out_sparse_dense_cpu(t, t, sparse, dense, beta, alpha);
1354 }
1355
1356 // NB: Purposely no broadcasting version of addmm inplace
1357
1358 Tensor _sparse_addmm(
1359 const Tensor& t,
1360 const SparseTensor& sparse,
1361 const Tensor& dense,
1362 const Scalar& beta,
1363 const Scalar& alpha
1364 ) {
1365 // _sparse_addmm forward is functionally equivalent to addmm; it's
1366 // just the backward that is different. This technically does an
1367 // unnecessary redispatch, which has not been worth eliminating.
1368 return at::addmm(t, sparse, dense, beta, alpha);
1369 }
1370
1371 Tensor _sparse_mm(
1372 const Tensor& mat1,
1373 const Tensor& mat2
1374 ) {
1375 if (mat1.is_sparse() && mat2.is_sparse()) {
1376 return at::_sparse_sparse_matmul(mat1, mat2);
1377 }
1378 if (mat1.is_sparse() || at::sparse_csr::is_sparse_compressed(mat1)) {
1379 Tensor t = at::zeros({mat1.size(-2), mat2.size(-1)}, mat2.options());
1380 return at::_sparse_addmm(t, mat1, mat2, 0, 1);
1381 }
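  // Remaining case: mat1 is strided, so mat2 supplies the sparse operand. Since
  // _sparse_addmm expects the sparse matrix first, use the transpose identity
  // A @ B == (B^T @ A^T)^T: multiply the transposes with mat2^T in the sparse slot,
  // then transpose the result back.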
1382 Tensor t = at::zeros({mat1.size(-2), mat2.size(-1)}, mat1.options());
1383 return at::_sparse_addmm(t.transpose(-2, -1), mat2.transpose(-2, -1), mat1.transpose(-2, -1), 0, 1).transpose(-2, -1);
1384 }
1385
1386 // NB: Despite its suggestive name, this actually only exists so that
1387 // we can redispatch to addmm_out; this is NOT an implementation of
1388 // the sparse masking version of mm
1389 SparseTensor& _sparse_mm_out(const SparseTensor& sparse,
1390 const Tensor& dense,
1391 SparseTensor& result) {
1392 Tensor t = at::zeros({}, dense.options());
1393 return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch!
1394 }
1395
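// Note: `reduce` selects how values scattered to the same output row are combined;
// per the torch.sparse.mm documentation the accepted strings are "sum", "mean",
// "amax" and "amin" (listed here for reference only).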
1396 Tensor _sparse_mm(const Tensor& mat1, const Tensor& mat2, const c10::string_view reduce) {
1397 // result: out, arg_out
1398 auto result = at::_sparse_mm_reduce_impl(mat1, mat2, reduce);
1399 return std::get<0>(result);
1400 }
1401
1402 // --------------------------------------------------------------------
1403 // hspmm(SparseTensor mat1, Tensor mat2)
1404 // --------------------------------------------------------------------
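// hspmm produces a "hybrid" result: a sparse tensor with sparse_dim == 1 and
// dense_dim == 1 whose values hold rows of mat1 @ mat2. Only rows of mat1 that
// contain at least one nonzero appear in the output's indices.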
1405
1406 SparseTensor& hspmm_out_sparse_cpu(const SparseTensor& sparse_, const Tensor& dense, SparseTensor& r) {
1407 // TODO: Make this a real argument
1408 Scalar alpha = 1;
1409
1410 AT_ASSERT(!sparse_.is_cuda()); // dispatch argument
1411 TORCH_CHECK(!r.is_cuda(), "hspmm: expected 'out' to be CPU tensor, but got CUDA tensor");
1412 TORCH_CHECK(!dense.is_cuda(), "hspmm: expected 'other' to be a CPU tensor, but got a CUDA tensor");
1413
1414 TORCH_CHECK(sparse_.sparse_dim() == 2,
1415 "hspmm: Argument #2: matrices expected, got ", sparse_.sparse_dim(), "D tensor");
1416 TORCH_CHECK(sparse_.dense_dim() == 0,
1417 "hspmm: Argument #2: scalar values expected, got ", sparse_.dense_dim(), "D values");
1418 TORCH_CHECK(dense.dim() == 2,
1419 "hspmm: Argument #3: matrices expected, got ", dense.dim(), "D tensor");
1420
1421 int64_t m = sparse_.size(0);
1422 int64_t k = sparse_.size(1);
1423 int64_t n = dense.size(1);
1424
1425 TORCH_CHECK(dense.size(0) == k,
1426 "hspmm: Argument #3: Expected dim 0 size ", k, ", got ", dense.size(0));
1427
1428 get_sparse_impl(r)->raw_resize_(1, 1, {m, n});
1429
1430 SparseTensor sparse = sparse_.coalesce();
1431
1432 int64_t nnz = sparse._nnz();
1433
1434 if (nnz == 0) {
1435 r.zero_();
1436 return r;
1437 }
1438
1439 Tensor indices = at::empty({1, nnz}, at::initialTensorOptions().dtype(kLong));
1440
1441   // Initialize the sparse matrix that will be used with the sparse-dense addmm below
1442   // to scatter rows from the dense matrix into rows of the output's values tensor
1443 SparseTensor newSparse = sparse.clone();
1444 Tensor spIndices = newSparse._indices();
1445 Tensor valueIndices = spIndices.select(0, 0);
1446
1447 // Compute output indices
1448 auto valueIndices_accessor = valueIndices.accessor<int64_t, 1>();
1449 auto indices_accessor = indices.accessor<int64_t, 2>();
1450
1451 int64_t i = -1, prevIdx = -1;
1452 for (const auto j : c10::irange(nnz)) {
1453 int64_t currIdx = valueIndices_accessor[j];
1454 if (currIdx != prevIdx) {
1455 indices_accessor[0][++i] = currIdx;
1456 prevIdx = currIdx;
1457 }
1458 valueIndices_accessor[j] = i;
1459 }
1460 int64_t outNnz = i + 1;
1461 indices.resize_({1, outNnz});
1462 Tensor values = at::empty({outNnz, n}, dense.options());
1463
1464 std::vector<int64_t> new_size = get_sparse_impl(newSparse)->sizes().vec();
1465 new_size[0] = outNnz;
1466 get_sparse_impl(newSparse)->raw_resize_(get_sparse_impl(newSparse)->sparse_dim(), get_sparse_impl(newSparse)->dense_dim(), new_size);
1467
1468 // Compute output values tensor with sparse * dense multiplication
1469 s_addmm_out_sparse_dense_cpu(values, values, newSparse, dense, 0, alpha);
1470 get_sparse_impl(r)->set_indices_and_values_unsafe(indices, values);
1471
1472 return r;
1473 }
1474
1475 SparseTensor hspmm_sparse_cpu(const SparseTensor& sparse, const Tensor& dense) {
1476 SparseTensor r = at::empty({0}, sparse.options());
1477 hspmm_out_sparse_cpu(sparse, dense, r);
1478 return r;
1479 }
1480
1481 // --------------------------------------------------------------------
1482 // sspaddmm(S1, S2, D, beta, alpha) -> S
1483 //
1484 // S = beta * S1 + alpha * mm(S2, D)
1485 // --------------------------------------------------------------------
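//
// Illustrative usage (a minimal sketch; shapes and values are hypothetical):
//
//   at::Tensor S1 = at::rand({3, 4}).to_sparse();  // i x k
//   at::Tensor S2 = at::rand({3, 5}).to_sparse();  // i x j
//   at::Tensor D  = at::rand({5, 4});              // j x k
//   at::Tensor S  = at::sspaddmm(S1, S2, D, /*beta=*/1, /*alpha=*/1);  // sparse, i x k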
1486
1487 SparseTensor& _sspaddmm_out_cpu(
1488 const SparseTensor& t,
1489 const SparseTensor& sparse_,
1490 const Tensor& dense,
1491 const Scalar& beta,
1492 const Scalar& alpha,
1493 SparseTensor& r) {
1494 AT_ASSERT(!t.is_cuda()); // dispatch argument
1495 TORCH_CHECK(!r.is_cuda(), "sspaddmm: expected 'out' to be CPU tensor, but got CUDA tensor");
1496 TORCH_CHECK(!sparse_.is_cuda(), "sspaddmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
1497 TORCH_CHECK(!dense.is_cuda(), "sspaddmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
1498
1499 TORCH_CHECK(sparse_.sparse_dim() == 2,
1500 "sspaddmm: Argument #2: matrices expected, got ", sparse_.sparse_dim(), "D tensor");
1501 TORCH_CHECK(sparse_.dense_dim() == 0,
1502 "sspaddmm: Argument #2: scalar values expected, got ", sparse_.dense_dim(), "D values");
1503 TORCH_CHECK(dense.dim() == 2,
1504 "sspaddmm: Argument #2: matrices expected, got ", dense.dim(), "D tensor");
1505
1506 SparseTensor sparse = sparse_.coalesce();
1507
1508 // ixj * jxk = ixk
1509 int64_t dim_i = sparse.size(0);
1510 int64_t dim_j = sparse.size(1);
1511 int64_t dim_k = dense.size(1);
1512
1513 // NB: This has to occur before the checks, because r may alias t.
1514 // See test_saddmm
1515 get_sparse_impl(r)->raw_resize_(2, 0, {dim_i, dim_k});
1516
1517 TORCH_CHECK(dense.size(0) == dim_j,
1518 "sspaddmm: Argument #3: Expected dim 0 size ", dim_j, ", got ", dense.size(0));
1519 TORCH_CHECK(t.size(0) == dim_i,
1520 "sspaddmm: Argument #1: Expected dim 0 size ", dim_i, ", got ", t.size(0));
1521 TORCH_CHECK(t.size(1) == dim_k,
1522 "sspaddmm: Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1));
1523
1524 int64_t nnz = sparse._nnz();
1525   // We have to make indices contiguous because we use indices.data_ptr in coo_to_csr, which assumes row-contiguous storage
1526 Tensor indices = sparse._indices().contiguous();
1527 Tensor values = sparse._values();
1528
1529 Tensor csr = coo_to_csr(indices.data_ptr<int64_t>(), dim_i, nnz);
1530
1531 int64_t t_nnz = t._nnz();
1532 int64_t r_nnz = nnz * dim_k + t_nnz;
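  // r_nnz is an upper bound: every row of `sparse` with at least one nonzero
  // contributes a dense row of dim_k output entries, plus the t_nnz entries
  // carried over from t; the final count is trimmed via set_nnz_and_narrow below.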
1533 Tensor newi = at::empty({2, r_nnz}, kLong);
1534 Tensor newv = at::zeros(
1535 {r_nnz},
1536 optTypeMetaToScalarType(values.options().dtype_opt()),
1537 values.options().layout_opt(),
1538 values.options().device_opt(),
1539 values.options().pinned_memory_opt());
1540
1541 if (t_nnz != 0) {
1542 Tensor narrowi = newi.narrow(1, 0, t_nnz);
1543 Tensor narrowv = newv.narrow(0, 0, t_nnz);
1544
1545 narrowi.copy_(t._indices());
1546 narrowv.copy_(t._values());
1547 newv.mul_(beta);
1548 }
1549
1550   // compute alpha * (sparse @ dense) and write its rows after the entries copied from t
1551 int64_t p = t_nnz;
1552
1553 auto csr_accessor = csr.accessor<int64_t, 1>();
1554 auto indices_accessor = indices.accessor<int64_t, 2>();
1555 auto newi_accessor = newi.accessor<int64_t, 2>();
1556
1557 int64_t dense_stride0 = dense.stride(0);
1558 int64_t dense_stride1 = dense.stride(1);
1559 int64_t newv_stride0 = newv.stride(0);
1560
1561 AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
1562 values.scalar_type(), "sspmm", [&] {
1563 auto values_accessor = values.accessor<scalar_t, 1>();
1564 scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
1565 scalar_t* newv_ptr = newv.data_ptr<scalar_t>();
1566 scalar_t cast_alpha = alpha.to<scalar_t>();
1567
1568 for (const auto h : c10::irange(dim_i)) {
1569 int64_t i_start = csr_accessor[h];
1570 int64_t i_end = csr_accessor[h+1];
1571 for (const auto i : c10::irange(i_start, i_end)) {
1572 scalar_t val = values_accessor[i];
1573 int64_t col = indices_accessor[1][i];
1574 if (col >= 0 && col < dim_j) {
1575 at::native::cpublas::axpy<scalar_t>(dim_k,
1576 cast_alpha * val,
1577 dense_ptr + col * dense_stride0, dense_stride1,
1578 newv_ptr + p * newv_stride0, 1);
1579 } else {
1580 AT_ERROR("index out of bound. sspmm: ", col, " not between 1 and ", dim_j);
1581 }
1582 }
1583 // Fill up the indices with the right values
1584 if (i_start != i_end) {
1585 for (const auto i : c10::irange(dim_k)) {
1586 newi_accessor[0][p+i] = h;
1587 newi_accessor[1][p+i] = i;
1588 }
1589 p += dim_k;
1590 }
1591 }
1592 }
1593 );
1594
1595 // to avoid a clone
1596 get_sparse_impl(r)->set_indices_and_values_unsafe(newi, newv);
1597 get_sparse_impl(r)->set_nnz_and_narrow(p);
1598
1599 return r;
1600 }
1601
1602 // sparse, sparse, sparse, dense, real, real -> sparse
1603 Tensor& _sspaddmm_out_only_sparse(const Tensor& self,
1604 const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
1605 AT_ERROR("tensor.sspaddmm(...) can only be called on sparse tensors");
1606 }
1607
1608 // sparse, dense -> sparse
1609 Tensor smm(const Tensor& self, const Tensor& mat2) {
1610 auto result = at::empty({0}, self.options());
1611 at::sspaddmm_out(result, result, self, mat2, 0.0, 1.0);
1612 return result;
1613 }
1614
1615 // sparse, sparse, dense, real, real -> sparse
1616 Tensor sspaddmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2,
1617 const Scalar& beta, const Scalar& alpha) {
1618 auto result = at::empty({0}, self.options());
1619 at::sspaddmm_out(result, self, mat1, mat2, beta, alpha);
1620 return result;
1621 }
1622
1623 // --------------------------------------------------------------------
1624 // sparse.sum()
1625 //
1626 // This implementation calls coalesce() to do the sum reduction on
1627 // sparse dims. Ideally in the future there should be a unified reduction function
1628 // for ops like sum, max, and min.
1629 // --------------------------------------------------------------------
1630 Tensor _sparse_sum(const SparseTensor& input) {
1631 return input.coalesce().values().sum();
1632 }
1633
1634 Tensor _sparse_sum(const SparseTensor& input, ScalarType dtype) {
1635   // We don't have to convert to the correct dtype first;
1636   // we just need to set up the accumulator correctly
1637 return input.coalesce().values().sum(dtype);
1638 }
1639
1640 Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum, ScalarType dtype) {
1641 return at::_sparse_sum(input.to(dtype), dims_to_sum);
1642 }
1643
1644 Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
1645 const int64_t input_dim = input.dim();
1646 auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
1647 auto dims_to_sum_v = dims_to_sum.vec();
1648 maybe_wrap_dims(dims_to_sum_v, input_dim);
1649
1650 Tensor indices = input._indices();
1651 Tensor values = input._values();
1652 IntArrayRef sizes = input.sizes();
1653 const int64_t sparse_dim = input.sparse_dim();
1654
1655 auto dims_to_keep_v = std::vector<int64_t>();
1656 auto dense_dims_to_sum_v = std::vector<int64_t>();
1657 for (const auto d : c10::irange(input_dim)) {
1658 if (dims_to_sum_b[d]) {
1659 if (d >= sparse_dim) dense_dims_to_sum_v.emplace_back(d + 1 - sparse_dim);
1660 }
1661 else {
1662 dims_to_keep_v.emplace_back(d);
1663 }
1664 }
1665 const int64_t sparse_dims_to_sum_size = dims_to_sum_v.size() - dense_dims_to_sum_v.size();
1666 const bool sum_all_sparse_dim = (sparse_dim == sparse_dims_to_sum_size);
1667 const bool sum_dense_dim = (!dense_dims_to_sum_v.empty());
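  // Illustrative bookkeeping example (hypothetical shapes): with input_dim = 3,
  // sparse_dim = 2 and dims_to_sum = {1, 2}, dim 1 is a summed sparse dim and dim 2
  // maps to dense dim 1 of `values` (the +1 skips the nnz dimension), so
  // sparse_dims_to_sum_size == 1, dense_dims_to_sum_v == {1} and dims_to_keep_v == {0}.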
1668
1669 // new values
1670 Tensor new_values;
1671 if (sum_dense_dim) {
1672 new_values = values.sum(dense_dims_to_sum_v);
1673 }
1674 else {
1675 new_values = values.clone(at::MemoryFormat::Contiguous);
1676 }
1677
1678 if (sum_all_sparse_dim) {
1679 // return a dense tensor if sum over all sparse dims
1680 new_values = new_values.sum(0);
1681 return new_values;
1682 }
1683 else { // !sum_all_sparse_dim
1684 // new indices
1685 Tensor new_indices;
1686 if (sparse_dims_to_sum_size == 0) {
1687 new_indices = indices.clone(at::MemoryFormat::Contiguous);
1688 }
1689 else {
1690 new_indices = at::empty({sparse_dim - sparse_dims_to_sum_size, input._nnz()}, indices.options());
1691 for (auto i: c10::irange(dims_to_keep_v.size())) {
1692 int64_t d = dims_to_keep_v[i];
1693 if (d < sparse_dim) new_indices[i].copy_(indices[d]);
1694 else break;
1695 }
1696 }
1697
1698 // new size
1699 int64_t new_sparse_dim = new_indices.size(0);
1700 int64_t new_dense_dim = new_values.dim() - 1; // exclude nnz dim
1701 std::vector<int64_t> new_sizes;
1702 new_sizes.reserve(dims_to_keep_v.size());
1703 for (auto d : dims_to_keep_v) new_sizes.emplace_back(sizes[d]);
1704 if (sum_all_sparse_dim) new_sizes.emplace(new_sizes.begin(), 1);
1705
1706 // use coalesce() to do sum reduction
1707 bool is_coalesced = false; // TODO: can we use input.is_coalesced()?
1708 SparseTensor new_sparse = at::_sparse_coo_tensor_with_dims_and_tensors(new_sparse_dim, new_dense_dim, new_sizes, new_indices, new_values, input.options(), is_coalesced);
1709 new_sparse = new_sparse.coalesce();
1710 return new_sparse;
1711 }
1712
1713 }
1714 // --------------------------------------------------------------------
1715 // NOTE [ sparse.sum() backward ]
1716 //
1717 // When summing over sparse_dim, the backward pass scatters gradients from the grad tensor to the input tensor.
1718 // Grad and input need to align indices over the sparse_dim that are not summed (given
1719 // input.sparse_dim >= grad.sparse_dim). The implementation here compares each pair of
1720 // indices between grad and input. When a matching index pair (input_i, grad_i) is found,
1721 // copy grad.values[grad_i] -> input_grad.values[input_i]. E.g.,
1722 //
1723 // input.sparse_dim = [5, 5]
1724 // input.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
1725 // [1, 4, 4, 0, 1, 3, 2, 4]]
1726 // input.values = [0, 1, 2, 3, 4, 5, 6, 7]
1727 // ...
1728 // sparse.sum(input, [0])
1729 // backward(...)
1730 // ...
1731 // grad.indices = [[0, 1, 2, 3]]
1732 // grad.values = [1, 2, 0, 4]
1733 //
1734 // # after indices matching
1735 // input grad
1736 // [[0, 1], -> [1]
1737 // [0, 4], -> [ ]
1738 // [1, 4], -> [ ]
1739 // [2, 0], -> [0]
1740 // [2, 1], -> [1]
1741 // [3, 3], -> [3]
1742 // [4, 2], -> [2]
1743 // [4, 4]]) -> [ ]
1744 //
1745 // input_grad.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
1746 // [1, 4, 4, 0, 1, 3, 2, 4]]
1747 // input_grad.values = [2, 0, 0, 1, 2, 4, 0, 0]
1748 //
1749 // Note that we allow input to be uncoalesced in the forward,
1750 // but we have to coalesce input in the backward, because grad-of-input
1751 // takes the same indices as input. If input is not coalesced, then
1752 // coalescing grad-of-input may add up grad values for duplicate indices,
1753 // and hence produce a wrong grad-of-input.
1754 //
1755 // Other edge cases:
1756 // - assign zero values to input gradients if no matching indices are found in grad
1757 // - grad.values might have zeros
1758 // --------------------------------------------------------------------
1759 Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_, IntArrayRef dims_to_sum) {
1760 TORCH_CHECK(!grad_.is_cuda(), "_sparse_sum_backward_cpu: expected 'grad_' to be CPU tensor, but got CUDA tensor");
1761 TORCH_CHECK(!input_.is_cuda(), "_sparse_sum_backward_cpu: expected 'input_' to be CPU tensor, but got CUDA tensor");
1762
1763 // Short circuit if grad is either zero or empty.
1764 if (((grad_.is_sparse() || at::sparse_csr::is_sparse_compressed(grad_)) && !grad_._nnz()) || !grad_.numel()) {
1765 return at::zeros_like(input_);
1766 }
1767
1768 auto input = input_.coalesce();
1769 const int64_t input_dim = input.dim();
1770 auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
1771 auto dims_to_sum_v = dims_to_sum.vec();
1772 maybe_wrap_dims(dims_to_sum_v, input_dim);
1773
1774 Tensor input_indices = input._indices();
1775 Tensor input_values = input._values();
1776 IntArrayRef input_sizes = input.sizes();
1777 const int64_t input_sparse_dim = input.sparse_dim();
1778 const int64_t input_dense_dim = input.dense_dim();
1779 const int64_t input_nnz = input._nnz();
1780
1781 int64_t sparse_dims_to_sum_size = 0;
1782 auto sparse_dims_to_keep_v = std::vector<int64_t>();
1783 auto dense_dims_to_sum_v = std::vector<int64_t>();
1784 for (auto d: c10::irange(input_dim)) {
1785 if (dims_to_sum_b[d]) {
1786 if (d < input_sparse_dim) sparse_dims_to_sum_size ++;
1787 else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim);
1788 }
1789 else {
1790 if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d);
1791 }
1792 }
1793
1794 const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
1795 const bool sum_dense_dim = (!dense_dims_to_sum_v.empty());
1796 const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);
1797
1798 if (sum_all_sparse_dim) {
1799 TORCH_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be dense since all sparse dims are summed");
1800 auto grad_input_values = grad_;
1801 auto expand_size = input_values.sizes().vec();
1802 if (sum_dense_dim) {
1803 auto dense_expand_size = std::vector<int64_t>(expand_size);
1804 dense_expand_size.erase(dense_expand_size.begin());
1805 AT_ASSERT(dense_expand_size.size() == static_cast<size_t>(input_values.dim() - 1));
1806 for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1); // -1 since grad has no nnz dim
1807 grad_input_values = grad_input_values.expand(dense_expand_size);
1808 }
1809 grad_input_values = grad_input_values.expand(expand_size).clone(at::MemoryFormat::Contiguous);
1810 bool grad_is_coalesced = input.is_coalesced();
1811 return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, input.options().dtype(grad_.dtype()), grad_is_coalesced); // convert to grad dtype
1812 }
1813 else {
1814 TORCH_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense");
1815 auto grad = grad_.coalesce();
1816 Tensor grad_indices = grad._indices();
1817 Tensor grad_values = grad._values();
1818 const int64_t grad_sparse_dim = grad.sparse_dim();
1819 const int64_t grad_nnz = grad._nnz();
1820
1821 Tensor grad_values_expand = grad_values;
1822 if (sum_dense_dim) {
1823 auto expand_size = input_values.sizes().vec();
1824 if (sum_sparse_dim) expand_size[0] = grad_values.size(0);
1825 for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d);
1826 grad_values_expand = grad_values_expand.expand(expand_size).clone(at::MemoryFormat::Contiguous);
1827 }
1828
1829 Tensor grad_input_values;
1830 if (sum_sparse_dim) {
1831 // see NOTE [ sparse.sum() backward ]
1832 grad_input_values = at::zeros_like(input_values, grad_values.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
1833
1834     // get flattened indices for grad and input
1835 auto grad_sparse_dim_to_keep_v = std::vector<int64_t>(grad_sparse_dim);
1836 std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0);
1837
1838     auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices over all sparse dims of grad; the output indices are coalesced and sorted
1839 auto grad_indices_1D_accessor = grad_indices_1D.accessor<int64_t, 1>();
1840 auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v);
1841 auto input_indices_1D_accessor = input_indices_1D.accessor<int64_t, 1>();
1842
1843 const auto copy_iter = TensorIteratorConfig()
1844 .add_output(grad_input_values)
1845 .add_input(grad_values_expand)
1846 .resize_outputs(false)
1847 .declare_static_shape(grad_values_expand.sizes(), /*squash_dims=*/0)
1848 .build();
1849 const auto device_type = kCPU;
1850
1851 const auto gIv_data = reinterpret_cast<char*>(grad_input_values.data_ptr());
1852 const auto gOv_data = reinterpret_cast<char*>(grad_values_expand.data_ptr());
1853 const auto gIv_stride = (grad_input_values.strides()[0] *
1854 grad_input_values.element_size());
1855 const auto gOv_stride = (grad_values_expand.strides()[0] *
1856 grad_values_expand.element_size());
1857
1858 // binary search to find matching indices
1859 at::parallel_for(0, input_nnz, 0, [&](int64_t start, int64_t end) {
1860 TensorIterator copy_iter_local(copy_iter);
1861
1862 for (auto i: c10::irange(start, end)) {
1863 int64_t input_idx = input_indices_1D_accessor[i];
1864 int64_t l = 0, r = grad_nnz - 1;
1865 while (l <= r) {
1866 int64_t m = l + (r - l) / 2;
1867 if (grad_indices_1D_accessor[m] == input_idx) {
1868 // grad_input_values[i].copy_(grad_values_expand[m])
1869 copy_iter_local.unsafe_replace_operand(0, gIv_data + i * gIv_stride);
1870 copy_iter_local.unsafe_replace_operand(1, gOv_data + m * gOv_stride);
1871 copy_stub(device_type, copy_iter_local, /*non_blocking=*/false);
1872 break;
1873 }
1874 if (grad_indices_1D_accessor[m] < input_idx) {
1875 l = m + 1;
1876 }
1877 else {
1878 r = m - 1;
1879 }
1880 }
1881 }
1882 });
1883 }
1884 else {
1885 grad_input_values = grad_values_expand;
1886 }
1887 bool grad_is_coalesced = input.is_coalesced();
1888 return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, grad.options(), grad_is_coalesced);
1889 }
1890 }
1891
1892 Tensor any_sparse(const Tensor& self) {
1893 TORCH_INTERNAL_ASSERT(self.is_sparse());
1894
1895 return at::any(self._values());
1896 }
1897
1898 Tensor bmm_sparse_cpu(const SparseTensor& self, const Tensor& mat2) {
1899 Tensor result = at::empty({}, mat2.options());
1900 return bmm_out_sparse_cpu(self, mat2, result);
1901 }
1902
1903 // Search a sorted strided array for the rightmost instance of a value.
1904 // The array must be sorted from lowest to highest.
1905 // Returns the index of the found element, or -1 if the value is not present.
1906 // Sets *found to true if the search value was found, false otherwise.
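// Illustrative example (hypothetical data): searching [0, 0, 1, 1, 1, 2] for the
// value 1 returns index 4 (the rightmost 1) and sets *found to true; searching the
// same array for 3 returns -1 and sets *found to false.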
1907 template<typename scalar_t>
1908 scalar_t binary_search_strided_rightmost(scalar_t search_val, TensorAccessor<scalar_t, 1>& sorted_arr_accessor, int64_t sorted_arr_begin_idx, int64_t length, bool* found) {
1909 if (length == 0) {
1910 *found = false;
1911 return -1;
1912 }
1913
1914 int64_t left_ind = 0;
1915 int64_t right_ind = length - 1;
1916 // This value should be overwritten in the loop so we use
1917 // a destructive initial value to ensure disaster if that
1918 // turns out not to be the case.
1919 int64_t mid_ind = std::numeric_limits<int64_t>::max();
1920 bool done_searching = false;
1921
1922 while (!done_searching) {
1923 mid_ind = left_ind + (right_ind - left_ind) / 2;
1924 scalar_t mid_val = sorted_arr_accessor[sorted_arr_begin_idx + mid_ind];
1925
1926 if (mid_val > search_val) {
1927 right_ind = mid_ind-1;
1928 } else if((mid_val == search_val) && (
1929 (mid_ind == length - 1) || (sorted_arr_accessor[sorted_arr_begin_idx + mid_ind + 1] != search_val)
1930 )) {
1931 done_searching = true;
1932 *found = true;
1933 } else {
1934 left_ind = mid_ind+1;
1935 }
1936
1937 if (left_ind > right_ind) {
1938 done_searching = true;
1939 *found = false;
1940 mid_ind = -1;
1941 }
1942 }
1943
1944 return mid_ind;
1945 }
1946
1947 Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor& result) {
1948 TORCH_CHECK(!mat2.is_sparse(), "bmm_sparse: Tensor 'mat2' must be dense");
1949
1950 TORCH_CHECK(self.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self.dense_dim());
1951 TORCH_CHECK(self.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self.sparse_dim());
1952 TORCH_CHECK(mat2.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2.dim());
1953
1954 TORCH_CHECK(self.size(0) == mat2.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match");
1955 TORCH_CHECK(self.size(2) == mat2.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match");
1956
1957 result.resize_({self.size(0), self.size(1), mat2.size(2)});
1958
1959 if (self._nnz() == 0) {
1960 result.zero_();
1961 return result;
1962 }
1963
1964   // We first need to coalesce so that the batch (first-dimension) indices are in
1965   // order, since each matrix is sent into the MM operation separately.
1966 SparseTensor self_coalesced = self.coalesce();
1967
1968 int64_t nnz = self_coalesced._nnz();
1969 Tensor indices = self_coalesced._indices();
1970 Tensor values = self_coalesced._values();
1971
1972 Tensor indices_dim0 = indices[0];
1973 auto indices_dim0_accessor = indices_dim0.accessor<int64_t, 1>();
1974 Tensor indices_dim1_dim2 = indices.slice(0, 1, 3);
1975
1976 int64_t dim_i = self_coalesced.size(1);
1977 int64_t dim_j = self_coalesced.size(2);
1978 int64_t dim_k = mat2.size(2);
1979
1980 Scalar beta = 0;
1981 Tensor t_dummy;
1982 Scalar alpha = 1;
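  // t_dummy is intentionally left undefined: with beta == 0 the addmm worker is
  // expected to skip the `t` term entirely, so the placeholder is never read
  // (an assumption about s_addmm_out_sparse_dense_worker, defined earlier in this file).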
1983
1984 int64_t mat_el_begin_idx = 0;
1985
1986 int64_t num_matrices = self_coalesced.size(0);
1987
1988 // Iterate through each set of 2D matrices within the 3D
1989 // tensor inputs, performing a matrix multiply with each one.
1990 int64_t start_mat_num = indices_dim0_accessor[0];
1991 AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
1992 values.scalar_type(), "bmm_sparse_dense", [&] {
1993 for (int64_t cur_mat_num = 0;
1994 (cur_mat_num < num_matrices);
1995 cur_mat_num++
1996 ) {
1997 // If there are sparse matrices at the beginning or end that
1998 // have all zero elements, we need to zero out the result matrix.
1999 if ((cur_mat_num < start_mat_num) || (mat_el_begin_idx >= nnz)) {
2000 result[cur_mat_num].zero_();
2001 continue;
2002 }
2003
2004 // Search for the range of sparse tensor elements that
2005 // correspond to the current matrix number. We already know
2006 // where the current matrix begins, so we just need to find
2007 // the end. The search excludes everything to the left of
2008 // the starting point, for best performance
2009 bool mat_end_found;
2010 int64_t mat_el_end_idx = binary_search_strided_rightmost(
2011 cur_mat_num,
2012 indices_dim0_accessor,
2013 mat_el_begin_idx,
2014 nnz-mat_el_begin_idx,
2015 &mat_end_found
2016 ) + mat_el_begin_idx;
2017
2018 if (mat_end_found) {
2019 mat_el_end_idx++;
2020
2021 // Create tensors to view just the current set of matrices
2022 const Tensor dense_matrix = mat2[cur_mat_num];
2023 Tensor result_matrix = result[cur_mat_num];
2024 Tensor sparse_indices = indices_dim1_dim2.slice(1, mat_el_begin_idx, mat_el_end_idx);
2025 Tensor sparse_values = values.slice(0, mat_el_begin_idx, mat_el_end_idx);
2026 int64_t sparse_nnz = mat_el_end_idx - mat_el_begin_idx;
2027
2028
2029 s_addmm_out_sparse_dense_worker<scalar_t>(
2030 sparse_nnz,
2031 dim_i, dim_j, dim_k,
2032 result_matrix,
2033 beta, t_dummy, alpha,
2034 sparse_indices, sparse_values,
2035 dense_matrix
2036 );
2037 mat_el_begin_idx = mat_el_end_idx;
2038
2039 // If no elements for this sparse matrix are found, then
2040 // it's a zero matrix and we need to zero out the result
2041 } else {
2042 result[cur_mat_num].zero_();
2043 }
2044 }
2045 }
2046 );
2047 return result;
2048 }
2049
2050 Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
2051 TORCH_INTERNAL_ASSERT(input.is_sparse());
2052 if (!is_same_tensor(result, input)) {
2053 copy_sparse_to_sparse_(result, input);
2054 }
2055 if (!input.is_complex()) {
2056 return result;
2057 }
2058 Tensor result_values = result._values();
2059 at::conj_physical_out(result_values, input._values());
2060 return result;
2061 }
2062
2063 } // namespace at::native
2064