// aten/src/ATen/native/sparse/SparseTensorMath.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/TensorIndexing.h>
3 #include <ATen/native/sparse/SparseTensorMath.h>
4 
5 #include <c10/util/irange.h>
6 #include <c10/util/MaybeOwned.h>
7 #include <ATen/core/Tensor.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/native/sparse/SparseStubs.h>
10 #include <ATen/Parallel.h>
11 #include <ATen/SparseCsrTensorUtils.h>
12 #include <ATen/SparseTensorImpl.h>
13 #include <ATen/ExpandUtils.h>
14 #include <ATen/ScalarOps.h>
15 #include <ATen/InitialTensorOptions.h>
16 #include <ATen/WrapDimUtilsMulti.h>
17 #include <ATen/native/BinaryOps.h>
18 #include <ATen/native/Copy.h>
19 #include <ATen/native/CPUBlas.h>
20 #include <ATen/native/SparseTensorUtils.h>
21 
22 #ifndef AT_PER_OPERATOR_HEADERS
23 #include <ATen/Functions.h>
24 #include <ATen/NativeFunctions.h>
25 #else
26 #include <ATen/ops/_sparse_addmm.h>
27 #include <ATen/ops/_sparse_addmm_native.h>
28 #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
29 #include <ATen/ops/_sparse_mm_native.h>
30 #include <ATen/ops/_sparse_sum.h>
31 #include <ATen/ops/_sparse_sum_backward_native.h>
32 #include <ATen/ops/_sparse_sum_native.h>
33 #include <ATen/ops/_sparse_sparse_matmul.h>
34 #include <ATen/ops/_sparse_mm_reduce_impl.h>
35 #include <ATen/ops/_sparse_mm_reduce_impl_native.h>
36 #include <ATen/ops/add.h>
37 #include <ATen/ops/add_native.h>
38 #include <ATen/ops/addmm.h>
39 #include <ATen/ops/addmm_native.h>
40 #include <ATen/ops/arange.h>
41 #include <ATen/ops/any.h>
42 #include <ATen/ops/any_native.h>
43 #include <ATen/ops/bmm_native.h>
44 #include <ATen/ops/cat.h>
45 #include <ATen/ops/conj_physical.h>
46 #include <ATen/ops/conj_physical_native.h>
47 #include <ATen/ops/copy_sparse_to_sparse.h>
48 #include <ATen/ops/div.h>
49 #include <ATen/ops/div_native.h>
50 #include <ATen/ops/empty.h>
51 #include <ATen/ops/empty_like.h>
52 #include <ATen/ops/floor_divide.h>
53 #include <ATen/ops/floor_divide_native.h>
54 #include <ATen/ops/hspmm_native.h>
55 #include <ATen/ops/mm_native.h>
56 #include <ATen/ops/mul.h>
57 #include <ATen/ops/mul_native.h>
58 #include <ATen/ops/mv_native.h>
59 #include <ATen/ops/native_norm_native.h>
60 #include <ATen/ops/neg_native.h>
61 #include <ATen/ops/pow.h>
62 #include <ATen/ops/pow_native.h>
63 #include <ATen/ops/result_type.h>
64 #include <ATen/ops/scalar_tensor.h>
65 #include <ATen/ops/smm_native.h>
66 #include <ATen/ops/sspaddmm.h>
67 #include <ATen/ops/sspaddmm_native.h>
68 #include <ATen/ops/sub_native.h>
69 #include <ATen/ops/zero_native.h>
70 #include <ATen/ops/zeros.h>
71 #include <ATen/ops/zeros_like.h>
72 #include <ATen/ops/zeros_native.h>
73 #include <ATen/ops/index.h>
74 #endif
75 
76 #include <algorithm>
77 
78 namespace at::native {
79 
80 using namespace at::sparse;
81 // --------------------------------------------------------------------
82 // zero_(SparseTensor)
83 // --------------------------------------------------------------------
84 
85 // hummu hummu
86 SparseTensor& zero_sparse_(SparseTensor& self) {
87   AT_ASSERT(self.is_sparse());
88   self.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
89   return self._coalesced_(true);
90 }
91 
92 // NB: Don't need zeros, zeros_like, already implemented in TensorFactories
93 
94 // --------------------------------------------------------------------
95 // mul(SparseTensor, Scalar)
96 // --------------------------------------------------------------------
97 
98 SparseTensor& mul_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) {
99   AT_ASSERT(r.is_sparse());
100   AT_ASSERT(t.is_sparse());
101   AT_ASSERT(value.dim() == 0);
102 
103   // Resolve a possibly sparse COO value to a strided tensor.
104   Tensor value_;
105   if (value.is_sparse()) {
106     if (value._nnz() == 0) {
107       r.resize_as_(t);
108       return r.zero_();
109     }
110     value_ = value.values();
111   } else {
112     value_ = value;
113   }
114   // With broadcasting in action, value_ may be a 1-D tensor as long
115   // as its shape is (1,).
116   AT_ASSERT(value_.numel() == 1);
117 
118   if (is_same_tensor(r, t)) {
119     r._values().mul_(value_);
120   } else {
121     r.resize_as_(t);
122     auto indices = r._indices();
123     indices.resize_as_(t._indices());
124     indices.copy_(t._indices());
125     Tensor r_values = r._values(); // Sigh... needed because mul_out takes Tensor&
126     at::mul_out(r_values, t._values(), value_);
127     get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
128     r._coalesced_(t.is_coalesced());
129   }
130   return r;
131 }
132 
133 SparseTensor& mul_out_sparse_scalar(SparseTensor& r, const SparseTensor& t, const Scalar& value) {
134   return mul_out_sparse_zerodim(r, t, wrapped_scalar_tensor(value));
135 }
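// Usage sketch for the scalar paths above (comment-only illustration, assuming the
// public ATen API; the shapes and values are made up):
//   auto s = at::rand({3, 4}).to_sparse();   // COO sparse tensor
//   auto r = at::mul(s, 2.5);                // result stays sparse: the indices are reused
//                                            // and only the values are scaled.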
136 
137 // --------------------------------------------------------------------
138 // neg(SparseTensor)
139 // --------------------------------------------------------------------
140 
141 SparseTensor& neg_out_sparse(const SparseTensor& t, SparseTensor& r) {
142   TORCH_CHECK(r.is_sparse(), "Tensor should be sparse");
143   TORCH_CHECK(t.is_sparse(), "Tensor should be sparse");
144 
145   // copy_sparse_ does not perform the copy if it is the same tensor
146   copy_sparse_to_sparse_(r, t);
147   r._values().neg_();
148   return r;
149 }
150 
151 SparseTensor neg_sparse(const SparseTensor& t) {
152   SparseTensor r = at::empty_like(t);
153   neg_out_sparse(t, r);
154   return r;
155 }
156 
157 SparseTensor& neg_sparse_(SparseTensor& t) {
158   return neg_out_sparse(t, t);
159 }
160 
161 // --------------------------------------------------------------------
162 // pow(SparseTensor, Scalar)
163 // --------------------------------------------------------------------
164 
165 // TODO: add in-place variant
166 
167 SparseTensor& pow_out_sparse_scalar(const SparseTensor& t_, const Scalar& value, SparseTensor& r) {
168   AT_ASSERT(r.is_sparse());
169   AT_ASSERT(t_.is_sparse());
170   TORCH_CHECK(value.toDouble() != 0, "pow: cannot raise to zeroth power on sparse tensor; it would make the result tensor dense");
171 
172   // This coalesce is why we can't easily provide an inplace variant
173   SparseTensor t = t_.coalesce();
174 
175   r.resize_as_(t);
176   auto indices = r._indices();
177   indices.resize_as_(t._indices());
178   indices.copy_(t._indices());
179   Tensor r_values = r._values(); // Sigh... needed because pow_out takes Tensor&
180   at::pow_out(r_values, t._values(), value);
181   get_sparse_impl(r)->set_nnz_and_narrow(t._nnz());
182   return r._coalesced_(t.is_coalesced());
183 }
184 
185 SparseTensor pow_sparse_scalar(const SparseTensor& t, const Scalar& value) {
186   SparseTensor r = at::empty({0}, t.options());
187   pow_out_sparse_scalar(t, value, r);
188   return r;
189 }
190 
191 // --------------------------------------------------------------------
192 // coalesce(SparseTensor)
193 // --------------------------------------------------------------------
194 
195 static SparseTensor& coalesce_(SparseTensor& tensor) {
196   if (tensor.is_coalesced()) {
197     return tensor;
198   }
199 
200   SparseTensor coalesced = tensor.coalesce();
201   tensor._values().resize_as_(coalesced._values());
202   tensor._indices().resize_as_(coalesced._indices());
203   tensor._values().copy_(coalesced._values());
204   tensor._indices().copy_(coalesced._indices());
205   tensor._coalesced_(true);
206   return tensor;
207 }
208 
209 // Note [Sparse Floor Division]
210 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
211 // Uncoalesced sparse tensors cannot be floor divided correctly. Integer
212 // division is considered a special-case of floor division for purposes of
213 // this note.
214 // For example, an integer tensor with values=[3, 3] divided by 2 would produce
215 // values=[1, 1], which sum to 2 instead of 3 (=6/2).
216 // A float tensor with values=[3., 3.] floor divided by 2 would also produce
217 // values=[1., 1.] (after truncation), which sum to 2.f instead of 3.f.
218 // To perform floor division the sparse tensor must be coalesced first.
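// A comment-only sketch of the example above (assuming the public ATen API; the names
// are illustrative):
//   auto indices = at::zeros({1, 2}, at::kLong);            // the same coordinate twice
//   auto values  = at::full({2}, 3, at::kLong);             // [3, 3] -> uncoalesced total of 6
//   auto s = at::sparse_coo_tensor(indices, values, {1});
//   // Floor-dividing the raw values by 2 gives [1, 1], which coalesces to 2, whereas
//   // floor(6 / 2) == 3. Hence the coalesce() performed before dividing below.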
219 // --------------------------------------------------------------------
220 // div(SparseTensor, Scalar)
221 // --------------------------------------------------------------------
222 
223 SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, std::optional<c10::string_view> rounding_mode, SparseTensor& r) {
224   TORCH_CHECK(value.dim() == 0, "Sparse division requires a scalar or ",
225     "zero-dim dense tensor divisor (got shape ", value.sizes(), " for divisor)");
226   TORCH_CHECK(!value.is_sparse(), "Sparse division requires a scalar or ",
227     "zero-dim dense tensor divisor (got a sparse divisor)");
228 
229   AT_ASSERT(r.is_sparse());
230   AT_ASSERT(t.is_sparse());
231 
232   // See note "Sparse Floor Division"
233   const bool should_coalesce = rounding_mode.has_value() && !t.is_coalesced();
234   if (is_same_tensor(r, t)) {
235     if (should_coalesce) {
236       coalesce_(r);
237     }
238     r._values().div_(value, rounding_mode);
239   } else {
240     Tensor t_tmp = t;
241     if (should_coalesce) {
242       t_tmp = t.coalesce();
243     }
244     r.resize_as_(t_tmp);
245     auto indices = r._indices();
246     indices.resize_as_(t_tmp._indices());
247     indices.copy_(t_tmp._indices());
248     Tensor r_values = r._values(); // Sigh... needed because div_out takes Tensor&
249     at::div_out(r_values, t_tmp._values(), value, rounding_mode);
250     get_sparse_impl(r)->set_nnz_and_narrow(t_tmp._nnz());
251     r._coalesced_(t_tmp.is_coalesced());
252   }
253   return r;
254 }
255 
256 SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value, SparseTensor& r) {
257   return div_out_sparse_zerodim(t, value, /*rounding_mode=*/std::nullopt, r);
258 }
259 
260 Tensor div_sparse(const Tensor& self, const Tensor& value) {
261   auto commonDtype = at::result_type(self, value);
262   if (c10::isIntegralType(commonDtype, /*includeBool=*/true)) {
263     commonDtype = typeMetaToScalarType(at::get_default_dtype());
264   }
265   Tensor result = at::empty({0}, self.options().dtype(commonDtype));
266   return div_out_sparse_zerodim(self, value, result);
267 }
268 
269 Tensor& div_sparse_(Tensor& self, const Tensor& value) {
270   return div_out_sparse_zerodim(self, value, self);
271 }
272 
273 Tensor div_sparse(const Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
274   auto commonDtype = at::result_type(self, value);
275   if (c10::isIntegralType(commonDtype, /*includeBool=*/true) && !rounding_mode.has_value()) {
276     commonDtype = typeMetaToScalarType(at::get_default_dtype());
277   }
278   Tensor result = at::empty({0}, self.options().dtype(commonDtype));
279   return div_out_sparse_zerodim(self, value, std::move(rounding_mode), result);
280 }
281 
282 Tensor& div_sparse_(Tensor& self, const Tensor& value, std::optional<c10::string_view> rounding_mode) {
283   return div_out_sparse_zerodim(self, value, std::move(rounding_mode), self);
284 }
285 
286 // --------------------------------------------------------------------
287 // floor_divide(SparseTensor, Scalar)
288 // --------------------------------------------------------------------
289 
290 SparseTensor& floor_divide_out_sparse_zerodim(const SparseTensor& dividend,
291   const Tensor& divisor,
292   SparseTensor& result) {
293   TORCH_CHECK(divisor.dim() == 0, "Sparse floor division requires a scalar or ",
294     "zero-dim dense tensor divisor (got shape ", divisor.sizes(), " for divisor)");
295   TORCH_CHECK(!divisor.is_sparse(), "Sparse floor division requires a scalar or ",
296     "zero-dim dense tensor divisor (got a sparse divisor)");
297 
298   AT_ASSERT(result.is_sparse());
299   AT_ASSERT(dividend.is_sparse());
300 
301   // Case 1: result and dividend are the same tensor
302   // Performs floor division in-place
303   if (is_same_tensor(result, dividend)) {
304 
305     // See note "Sparse Floor Division"
306     if (!result.is_coalesced()) {
307       coalesce_(result);
308     }
309 
310     result._values().floor_divide_(divisor);
311     return result;
312   }
313 
314   // Case 2: result and dividend are different tensors
315   Tensor dividend_tmp = dividend;
316 
317   // Ensures dividend_tmp is coalesced (see note above)
318   if (!dividend.is_coalesced()) {
319     dividend_tmp = dividend.coalesce();
320   }
321 
322   // Resizes and indexes result like dividend_tmp
323   result.resize_as_(dividend_tmp);
324   result._indices().resize_as_(dividend_tmp._indices());
325   result._indices().copy_(dividend_tmp._indices());
326 
327   // Computes result
328   Tensor result_values = result._values();
329   at::floor_divide_out(result_values, dividend_tmp._values(), divisor);
330   get_sparse_impl(result)->set_nnz_and_narrow(dividend_tmp._nnz());
331   result._coalesced_(dividend_tmp.is_coalesced());
332   return result;
333 }
334 
335 Tensor floor_divide_sparse(const Tensor& self, const Tensor& value) {
336   auto commonDtype = at::result_type(self, value);
337   Tensor result = at::empty({0}, self.options().dtype(commonDtype));
338   return floor_divide_out_sparse_zerodim(self, value, result);
339 }
340 
341 Tensor& floor_divide_sparse_(Tensor& self, const Tensor& value) {
342   return floor_divide_out_sparse_zerodim(self, value, self);
343 }
344 
345 // --------------------------------------------------------------------
346 // norm(SparseTensor, Scalar)
347 // --------------------------------------------------------------------
348 
349 // Only supports floating point, FYI
350 Tensor norm_sparse(const SparseTensor& self, const Scalar& p) {
351   AT_ASSERT(self.is_sparse());
352   return norm_sparse(self, p, IntArrayRef{}, false, std::nullopt);
353 }
354 
355 Tensor norm_sparse(const SparseTensor& self, const std::optional<Scalar>& p, IntArrayRef dim, bool keepdim, std::optional<ScalarType> dtype) {
356   AT_ASSERT(self.is_sparse());
357   if (!dim.empty()) {
358     // Only full reductions are supported, so check if that is the case
359     int64_t ndim = self.dim();
360     bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.size();
361     if (passed_full_reduction_check) {
362       auto dim_ = dim.vec();
363       maybe_wrap_dims(dim_, ndim);
364       std::vector<bool> dims_check(ndim, false);
365       // Need to check for duplicates, and fail if any are found
366       for (auto dim_ind : dim_) {
367         if (dims_check[dim_ind]) {
368           passed_full_reduction_check = false;
369           break;
370         }
371         dims_check[dim_ind] = true;
372       }
373     }
374     TORCH_CHECK(passed_full_reduction_check,
375       "norm_sparse currently only supports full reductions, so 'dim' must either be empty or contain all dimensions of the input");
376   }
377   TORCH_CHECK(keepdim == false, "norm_sparse currently does not support keepdim=True");
378   TORCH_CHECK(!dtype.has_value(), "norm_sparse currently does not support 'dtype' argument");
379   constexpr auto TWO = 2.0;
380   auto p_ = p.value_or(TWO);
381   return self.coalesce()._values().norm(p_);
382 }
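// Worked note for the reduction above: for a COO tensor s this computes
//   norm(s, p) == s.coalesce()._values().norm(p),
// which matches the dense result for p > 0 because unspecified elements are zero and
// contribute nothing. Partial 'dim' reductions, keepdim=True, and an explicit 'dtype'
// are rejected by the checks above.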
383 
384 // --------------------------------------------------------------------
385 // mv(SparseTensor, Tensor)
386 // --------------------------------------------------------------------
387 
388 Tensor mv_sparse(const SparseTensor& self, const Tensor& vec)
389 {
390   TORCH_CHECK(self.ndimension() == 2 &&
391               vec.ndimension() == 1,
392               "mv: expected a 2-D sparse tensor and a 1-D vector, but got ",
393               "SparseTensor dim: ", self.ndimension(), " and Tensor dim: ", vec.ndimension());
394 
395   TORCH_CHECK(vec.size(-1) == self.size(-1),
396               "mv: expected self.size(-1) == vec.size(-1)");
397 
398   auto result = self.matmul(vec.unsqueeze(-1));
399 
400   return result.squeeze(-1);
401 }
402 
403 // --------------------------------------------------------------------
404 // add(SparseTensor, SparseTensor, Scalar)  [broadcasts]
405 // --------------------------------------------------------------------
406 
407 Tensor add_sparse(const Tensor& self, const Tensor& other, const Scalar& alpha) {
408   // TODO: Why?! Can't we just flip the order here...
409   TORCH_CHECK(!(self.is_sparse() && !other.is_sparse()),
410               "add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
411   auto commonDtype = at::result_type(self, other);
412   alpha_check(commonDtype, alpha);
413   Tensor result = at::empty({0}, self.options().dtype(commonDtype));
414   return at::add_out(result, self, other, alpha);  // redispatch!
415 }
416 
417 Tensor& add_sparse_(Tensor& self, const Tensor& other, const Scalar& alpha) {
418   return at::add_out(self, self, other, alpha);  // redispatch!
419 }
420 
421 // There's actually nothing sparse specific about these implementations
422 
423 Tensor sub_sparse(const Tensor& self, const Tensor& other, const Scalar& alpha) {
424   sub_check(self, other);
425   return native::add_sparse(self, other, -alpha);
426 }
427 
428 Tensor& sub_sparse_(Tensor& self, const Tensor& other, const Scalar& alpha) {
429   sub_check(self, other);
430   return native::add_sparse_(self, other, -alpha);
431 }
432 
433 Tensor& sub_out_sparse(const Tensor& self, const Tensor& other, const Scalar& alpha, Tensor& r) {
434   sub_check(self, other);
435   return at::add_out(r, self, other, -alpha);  // redispatch!
436 }
437 
438 
439 static SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, const Scalar& value, ScalarType commonDtype) {
440     // saving those because they can be overwritten when doing in-place operations
441     int64_t t_nnz = t._nnz(), s_nnz = src._nnz(), max_nnz = t_nnz + s_nnz;
442     bool coalesced = t.is_coalesced() && src.is_coalesced();
443     int64_t sparse_dim = src.sparse_dim();
444 
445     Tensor r_indices = at::empty({src.sparse_dim(), max_nnz}, t._indices().options());
446 
447     Tensor t_values = t._values().to(commonDtype);
448     Tensor s_values = src._values().to(commonDtype);
449 
450     Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
451 
452     int64_t blockSize = r_values.stride(0);
453     int64_t r_i = 0, t_i = 0, s_i = 0;
454     auto t_indices = t._indices();
455     auto src_indices = src._indices();
456 
457     // NB: relies on nnz tests above
458     auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
459     auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
460     auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
461 
462     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
463         commonDtype, "cadd_sparse", [&] {
464           scalar_t* t_values_ptr = t_values.data_ptr<scalar_t>();
465           scalar_t* s_values_ptr = s_values.data_ptr<scalar_t>();
466           scalar_t* r_values_ptr = r_values.data_ptr<scalar_t>();
467           scalar_t cast_value = value.to<scalar_t>();
468           while (t_i < t_nnz || s_i < s_nnz) {
469             int64_t cmp;
470             if (t_i >= t_nnz) {
471               cmp = -1;
472             } else if (s_i >= s_nnz) {
473               cmp = 1;
474             } else {
475               cmp = 0;
476               for (auto d: c10::irange(sparse_dim)) {
477                 if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
478                   cmp = 1;
479                   break;
480                 }
481                 if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
482                   cmp = -1;
483                   break;
484                 }
485               }
486             }
487             if (cmp >= 0) {
488               for (auto d: c10::irange(sparse_dim)) {
489                 r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
490               }
491               if (t_values.numel() > 0) {  // We add all elements from t_values to r_values only if t_values is not an empty tensor
492                 at::native::cpublas::axpy<scalar_t>(blockSize, 1,
493                   t_values_ptr + t_i * blockSize, 1,
494                   r_values_ptr + r_i * blockSize, 1);
495               }
496               t_i++;
497             }
498             if (cmp <= 0) {
499               for (auto d: c10::irange(sparse_dim)) {
500                 r_indices_accessor[d][r_i] = src_indices_accessor[d][s_i];
501               }
502               if (s_values.numel() > 0) {  // We add all elements from s_values to r_values only if s_values is not an empty tensor
503                 at::native::cpublas::axpy<scalar_t>(blockSize, cast_value,
504                   s_values_ptr + s_i * blockSize, 1,
505                   r_values_ptr + r_i * blockSize, 1);
506               }
507               s_i++;
508             }
509             r_i++;
510           }
511         }
512     );
513 
514     if (r.scalar_type() != commonDtype) {
515       r_values = r_values.to(r.scalar_type());
516     }
517     get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
518     get_sparse_impl(r)->set_nnz_and_narrow(r_i);
519 
520     // TODO: I think it may be possible to track inside the loop and
521     // detect when we are uncoalesced (e.g., by observing that an
522     // index goes backwards) which may be more precise than using the
523     // coalesced flag here.  But this is easy.
524     return r._coalesced_(coalesced);
525 }
526 
527 static SparseTensor& add_out_sparse_non_contiguous(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, const Scalar& value, ScalarType commonDtype) {
528     Tensor t_values = t._values().to(commonDtype);
529     Tensor s_values = src._values().to(commonDtype);
530 
531     // If `t` or `src` contains non-contiguous `values`, `at::native::cpublas::axpy` doesn't work
532     // and we concat the indices and values tensors instead.
533     AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
534       commonDtype, "add_out_sparse_cpu", [&] {
535           if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
536             s_values = s_values.mul(value);
537           }
538         });
539 
540     Tensor r_indices = at::cat({t._indices(), src._indices()}, 1);
541     Tensor r_values = at::cat({t_values, s_values}, 0).to(r.scalar_type());
542     alias_into_sparse(r, r_indices, r_values);
543 
544     // Prevent unbounded growth of nnz
545     // TODO: Improved heuristic on when to coalesce or remove need to coalesce
546     if (r._nnz() > r.numel()) {
547       auto c = r.coalesce();
548       alias_into_sparse(r, c._indices(), c._values());
549     }
550 
551     return r;
552 }
553 
554 Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, const Scalar& value);
555 
556 SparseTensor& add_out_sparse_cpu(const SparseTensor& t, const SparseTensor& src, const Scalar& value, SparseTensor& r) {
557   if (!t.is_sparse()) {
558     return add_out_dense_sparse_cpu(r, t, src, value);
559   }
560   // TODO: This test seems a bit goofy
561   TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
562   AT_ASSERT(!t.is_cuda());  // the dispatch argument
563   TORCH_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor");
564   TORCH_CHECK(!src.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor");
565 
566   TORCH_CHECK(t.sizes().equals(src.sizes()), "add: expected sizes of 'self' and 'other' to match, but ", t.sizes(), " != ", src.sizes());
567 
568   auto commonDtype = promoteTypes(t.scalar_type(), src.scalar_type());
569 
570   TORCH_CHECK(canCast(commonDtype, r.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r.scalar_type(), " in add operation");
571 
572   if (src._nnz() == 0) {
573     return copy_sparse_to_sparse_(r, t);
574   }
575   if (t._nnz() == 0) {
576     return mul_out_sparse_scalar(r, src, value);
577   }
578 
579   TORCH_CHECK(is_same_density(t, src), "add: expected 'self' and 'other' to have same density, but 'self' has ", t.sparse_dim(), " sparse dimensions while 'other' has ", src.sparse_dim(), " sparse dimensions");
580 
581   r.resize_as_(src);
582   if (r.is_meta()) {
583     return r;
584   } else if (src._values().is_contiguous() && t._values().is_contiguous()) {
585     return add_out_sparse_contiguous(r, t, src, value, commonDtype);
586   } else {
587     return add_out_sparse_non_contiguous(r, t, src, value, commonDtype);
588   }
589 }
590 
591 // --------------------------------------------------------------------
592 // add(Tensor, SparseTensor, Scalar)
593 //    formerly known as spcadd
594 // --------------------------------------------------------------------
595 template <typename scalar_t>
596 void add_dense_sparse_worker_non_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
597   auto indices_accessor = indices.accessor<int64_t, 2>();
598   auto values_accessor = values.accessor<scalar_t, 1>();
599 
600   scalar_t* r_ptr = r.data_ptr<scalar_t>();
601   scalar_t cast_value = value.to<scalar_t>();
602   const int64_t sparse_dim = sparse.sparse_dim();
603   std::vector<int64_t> result_stride(sparse_dim);
604   for (const auto d: c10::irange(sparse_dim)) {
605     result_stride[d] = r.stride(d);
606   }
607   at::parallel_for(0, sparse._nnz(), 0, [&](int64_t start, int64_t end) {
608     for (const auto k: c10::irange(start, end)) {
609       int64_t index = r.storage_offset();
610       for (auto d: c10::irange(sparse_dim)) {
611         index += result_stride[d] * indices_accessor[d][k];
612       }
613       r_ptr[index] += cast_value * values_accessor[k];
614     }
615   });
616 }
617 
618 template <typename scalar_t>
619 inline void add_dense_sparse_worker_hybrid_cpu(Tensor& r, const Scalar& value, const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
620 
621   // Get the dense dimension element numbers of hybrid sparse tensor
622   int64_t values_dense_size = values.stride(0);
623   TORCH_CHECK(values.is_contiguous());
624   scalar_t* v_ptr = values.data_ptr<scalar_t>();
625 
626   scalar_t* r_ptr = r.data_ptr<scalar_t>();
627   TORCH_CHECK(r_ptr != nullptr);
628 
629   auto indices_accessor = indices.accessor<int64_t, 2>();
630   scalar_t cast_value = value.to<scalar_t>();
631   auto sparse_dim = sparse.sparse_dim();
632   std::vector<int64_t> result_stride(sparse_dim);
633   for (auto d : c10::irange(sparse_dim)) {
634     result_stride[d] = r.stride(d);
635   }
636 
637   at::parallel_for(0, sparse._nnz(), 0, [&](int64_t start, int64_t end) {
638     for (auto k: c10::irange(start, end)) {
639       auto r_index = r_ptr;
640       for (auto d: c10::irange(sparse_dim)) {
641         r_index += result_stride[d] * indices_accessor[d][k];
642       }
643       auto v_index = v_ptr + k * values_dense_size;
644       at::native::cpublas::axpy<scalar_t>(values_dense_size, cast_value, v_index, 1, r_index, 1);
645     }
646   });
647 }
648 
649 template <typename scalar_t>
650 inline void add_dense_sparse_worker_non_coalesced_cpu(Tensor& r, const Scalar& value,
651     const SparseTensor& sparse, const Tensor& indices, const Tensor& values) {
652 
653   // Number of elements in the dense part of each value of the hybrid sparse tensor
654   auto values_dense_size = values.stride(0);
655   TORCH_CHECK(values.is_contiguous());
656   scalar_t* v_ptr = values.data_ptr<scalar_t>();
657   TORCH_CHECK(v_ptr != nullptr);
658 
659   scalar_t* r_ptr = r.data_ptr<scalar_t>();
660   TORCH_CHECK(r_ptr != nullptr);
661 
662   scalar_t cast_value = value.to<scalar_t>();
663   auto sparse_dim = sparse.sparse_dim();
664 
665   auto indices_accessor = indices.accessor<int64_t, 2>();
666   int64_t result_length = r.size(0);
667   std::vector<int64_t> result_stride(sparse_dim);
668   for (auto d : c10::irange(sparse_dim)) {
669     result_stride[d] = r.stride(d);
670   }
671 
672   auto sparse_nnz = sparse._nnz();
673   int max_threads = at::get_num_threads();
674   max_threads = (result_length < max_threads) ? result_length : max_threads;
675   int64_t avg_chunk_down = result_length / max_threads;
676   std::vector<int64_t> chuck_size(max_threads);
677   for (const auto i : c10::irange(max_threads)) {
678     chuck_size[i] = avg_chunk_down;
679   }
680   // Distribute the remainder so chunk sizes differ by at most one (e.g., 4 rows over 3 threads -> 2, 1, 1)
681   for (auto i = 0 ; i < result_length % max_threads ; i++) {
682     chuck_size[i] += 1;
683   }
684   std::vector<int64_t> chuck_sum_size(max_threads + 1);
685   chuck_sum_size[0] = 0;
686   for (const auto i : c10::irange(1, max_threads)) {
687     chuck_sum_size[i] = chuck_sum_size[i - 1] + chuck_size[i - 1];
688   }
689   chuck_sum_size[max_threads] = result_length;
690   at::parallel_for(0, max_threads, 0, [&](int64_t start, int64_t end) {
691     for (auto k: c10::irange(start, end)) {
692       int64_t chunk_begin = chuck_sum_size[k];
693       int64_t chunk_end = chuck_sum_size[k + 1];
694       for (const auto n: c10::irange(sparse_nnz)) {
695         int64_t chunk_offset = indices_accessor[0][n];
696         if (chunk_offset >= chunk_begin && chunk_offset < chunk_end) {
697           int64_t r_offset = result_stride[0] * chunk_offset;
698           for (const auto d : c10::irange(1, sparse_dim)) {
699             r_offset += result_stride[d] * indices_accessor[d][n];
700           }
701           scalar_t* v_index = v_ptr + n * values_dense_size;
702           auto r_index = r_ptr + r_offset;
703           at::native::cpublas::axpy<scalar_t>(values_dense_size, cast_value, v_index, 1, r_index, 1);
704         }
705       }
706     }
707   });
708 }
709 
710 Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, const Scalar& value) {
711   TORCH_CHECK(!r.is_sparse());
712   TORCH_CHECK(!dense.is_sparse());
713   TORCH_CHECK(sparse_.is_sparse());
714 
715   TORCH_CHECK(!dense.is_cuda()); // dispatch argument
716   TORCH_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor");
717   TORCH_CHECK(!sparse_.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor");
718 
719   TORCH_CHECK(dense.sizes().equals(sparse_.sizes()), "add: expected 'self' and 'other' to have same size, but self has size ",
720     dense.sizes(), " while other has size ", sparse_.sizes(), " (FYI: dense-sparse addition does not currently support broadcasting)");
721 
722   auto commonDtype = promoteTypes(dense.scalar_type(), sparse_.scalar_type());
723   TORCH_CHECK(canCast(commonDtype, r.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r.scalar_type(), " in add operation");
724 
725   r.resize_as_(dense);
726 
727   auto sparse_nnz = sparse_._nnz();
728   if (sparse_nnz == 0) {
729     if (!is_same_tensor(r, dense)) r.copy_(dense);
730     return r;
731   }
732 
733   int64_t dense_dim = dense.dim();
734   int64_t sparse_dim = sparse_.sparse_dim();
735   Tensor resultBuffer = r;
736   if (r.scalar_type() != commonDtype) {
737     resultBuffer = dense.to(commonDtype);
738   } else if (!is_same_tensor(r, dense)) {
739     resultBuffer.copy_(dense);
740   }
741 
742   Tensor values = sparse_._values();
743   bool sparse_is_coalesced = (sparse_.is_coalesced() || sparse_nnz == 1);
744   bool result_is_contiguous = ((r.storage().data() != nullptr) && resultBuffer.is_contiguous());
745   bool value_is_contiguous = values.is_contiguous();
746   bool is_contiguous =  (result_is_contiguous && value_is_contiguous);
747 
748   SparseTensor sparse = sparse_;
749   Tensor indices = sparse_._indices();
750   Tensor valuesBuffer = values.to(commonDtype);
751   if (is_contiguous && sparse_is_coalesced) {
752     // TODO: we can optimize the non-hybrid case by not using buffers
753     if (sparse_dim == dense_dim) {
754       AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
755           at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
756           commonDtype, "add_dense_sparse_non_hybrid", [&] {
757             add_dense_sparse_worker_non_hybrid_cpu<scalar_t>(resultBuffer, value, sparse_, indices, valuesBuffer);
758           });
759     } else {
760       AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
761           at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
762           commonDtype, "add_dense_sparse_hybrid", [&] {
763             add_dense_sparse_worker_hybrid_cpu<scalar_t>(resultBuffer, value, sparse_, indices, valuesBuffer);
764           });
765     }
766   } else if (is_contiguous && (sparse_dim > 0)) {
767     // Handle sparse is not coalesced
768     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
769         at::ScalarType::ComplexHalf, at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
770         commonDtype, "add_dense_sparse_worker_non_coalesced", [&] {
771           add_dense_sparse_worker_non_coalesced_cpu<scalar_t>(resultBuffer, value, sparse_, indices, valuesBuffer);
772         });
773   } else {
774     // Slow path for non-contiguous values and output
775     // TODO: coalesce() performance may be further improved
776     sparse = sparse_.coalesce();
777     indices = sparse._indices();
778     values = sparse._values();
779     valuesBuffer = values.to(commonDtype);
780     auto indices_accessor = indices.accessor<int64_t, 2>();
781     auto sparse_nnz = sparse._nnz();
782     at::parallel_for(0, sparse_nnz, 100, [&](int64_t start, int64_t end) {
783       for (auto k: c10::irange(start, end)) {
784         Tensor dstBuffer = resultBuffer;
785         for (auto d: c10::irange(sparse_dim)) {
786           dstBuffer = dstBuffer.select(0, indices_accessor[d][k]);
787         }
788         Tensor srcBuffer = valuesBuffer.select(0, k);
789         dstBuffer.add_(srcBuffer, value);
790       }
791     });
792   }
793   if (r.scalar_type() != commonDtype) {
794     r.copy_(resultBuffer);
795   }
796   return r;
797 }
798 
799 // --------------------------------------------------------------------
800 // mul(SparseTensor, SparseTensor)  [broadcasts]
801 // --------------------------------------------------------------------
802 
803 Tensor mul_sparse(const Tensor& self, const Tensor& other) {
804   auto commonDtype = at::result_type(self, other);
805   // Arbitrary (dense, sparse) and (sparse, dense) multiplication is not
806   // currently supported, but (0dim-dense, sparse) and (sparse, 0dim-dense) are.
807   // Make sure we use the sparse exemplar for the result.
808   auto result_options = self.is_sparse() ?
809     self.options().dtype(commonDtype) : other.options().dtype(commonDtype);
810   Tensor result = at::empty({0}, result_options);
811   return at::mul_out(result, self, other);  // redispatch!
812 }
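// Usage sketch for the routing above (comment-only, assuming the public ATen API):
//   auto s = at::rand({3, 4}).to_sparse();
//   auto r1 = at::mul(s, s);                        // sparse * sparse
//   auto r2 = at::mul(s, at::scalar_tensor(2.0));   // sparse * 0-dim dense
//   // Both produce sparse results; the empty result created above is filled in by the
//   // redispatched mul_out call (mul_out_sparse_cpu below on CPU).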
813 
814 Tensor& mul_sparse_(Tensor& self, const Tensor& other) {
815   if (self.is_sparse()) {
816     return at::mul_out(self, self, other);  // redispatch!
817   }
818   else {
819     const auto res = at::mul(self, other);
820     self.zero_();
821     self.add_(res);
822     return self;
823   }
824 }
825 
826 // A generic function to implement pointwise-like operations
827 // with index intersection between dense and sparse COO tensors.
828 // NOTE: op is always called as op(dense_values, sparse_values),
829 // so it is up to the user to supply right implementations for non-commutative
830 // operations.
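// Usage sketch (mirroring _mul_dense_sparse_out further below): the op receives the
// dense values already gathered at the sparse indices, so a pointwise multiply is
//
//   intersection_binary_op_sparse_dense_out(d, s, res, "mul",
//       [](const Tensor& a, const Tensor& b) -> Tensor { return at::mul(a, b); });
//
// A non-commutative op must order its arguments as op(dense_values, sparse_values),
// e.g. dense-minus-sparse would pass [](const Tensor& a, const Tensor& b) { return a - b; }.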
831 template <typename binary_func_t>
832 Tensor& intersection_binary_op_sparse_dense_out(
833     const Tensor& d,
834     const SparseTensor& s_,
835     Tensor& res,
836     const char* const op_name,
837     const binary_func_t& op,
838     const bool coalesce = false) {
839   // compute broadcasted shape.
840   const auto res_shape = infer_size(d.sizes(), s_.sizes());
841 
842   // Short-circuit if either s_ or d is empty.
843   if (!s_._nnz() || !s_.numel() || !d.numel()) {
844     const int64_t dense_dim = s_.dense_dim();
845     const int64_t sparse_dim = static_cast<int64_t>(res_shape.size()) - dense_dim;
846     const int64_t nnz = 0;
847     const auto indices = at::empty({sparse_dim, nnz}, s_._indices().options());
848     auto res_values_shape = s_._values().sizes().vec();
849     res_values_shape[0] = nnz;
850     const auto values = at::empty(res_values_shape, s_._values().options().dtype(res.scalar_type()));
851     auto* res_impl = get_sparse_impl(res);
852     res_impl->raw_resize_(sparse_dim, dense_dim, /*size=*/res_shape);
853     res_impl->set_indices_and_values_unsafe(indices, values);
854     res_impl->set_nnz_and_narrow(nnz);
855     return res._coalesced_(true);
856   }
857 
858   const auto d_dim = d.dim();
859   const auto s_dim = s_.dim();
860 
861   // Always coalesce when sparse broadcasts over dense,
862   // because new sparse dimensions are created and
863   // repeated indices therefore have to be eliminated.
864   const auto s = (coalesce || d_dim > s_dim) ? s_.coalesce() : s_;
865 
866   const auto sparse_dim = s.sparse_dim();
867   const auto dense_dim = s.dense_dim();
868 
869   const auto s_indices = s._indices();
870   const auto s_values = s._values();
871 
872   const auto apply_op = [&](const Tensor& d_filtered) -> Tensor& {
873     const auto res_indices = s_indices.clone();
874     // to(res.scalar_type()) is only performed when both d and s are 0-dim.
875     // This ensures the right type promotions with the following rules:
876     // op(0-dim, 0-dim).dtype == <common dtype>
877     // op(0-dim, ge-1-dim).dtype == <ge-1-dim>.dtype,
878     // where ge-1-dim is a tensor with dim >= 1.
879     // We do not cast if op is performed in-place.
880     // The cast is required if s is a 0-dim non-coalesced tensor and d is 0-dim.
881     // This is because s.values is at least 1D, so
882     // op(s.values, d).dtype == s.values.dtype, but we want
883     // op(s.values, d).dtype == <common dtype>.
884     const auto values = op(d_filtered, s_values);
885     const auto res_values = is_same_tensor(s_, res) ? values : values.to(res.scalar_type());
886     auto* res_impl = get_sparse_impl(res);
887     res_impl->raw_resize_(sparse_dim, dense_dim, res_shape);
888     res_impl->set_indices_and_values_unsafe(res_indices, res_values);
889     res_impl->set_nnz_and_narrow(s._nnz());
890     return res._coalesced_(s.is_coalesced());
891   };
892 
893   // Easiest case: only dense dimensions intersect.
894   // This means only value tensors interact.
895   if (d_dim <= dense_dim) {
896     return apply_op(d);
897   }
898 
899   // Now we have intersection between sparse and dense dims.
900   const auto sparse_dim_intersec = std::min(sparse_dim, d_dim - dense_dim);
901   const auto d_start_dim_intersec = std::max<int64_t>(0, d_dim - s_dim);
902   const auto s_start_dim_intersec = std::max<int64_t>(0, s_dim - d_dim);
903 
904   // Index d with s_indices to find values which
905   // interact with s_values.
906   const auto d_filtered = [&]() -> Tensor {
907     using at::indexing::Slice;
908     using at::indexing::Ellipsis;
909     using at::indexing::TensorIndex;
910 
911     std::vector<TensorIndex> intersec_indices;
912     intersec_indices.reserve(d_dim);
913 
914     if (d_start_dim_intersec) {
915       intersec_indices.emplace_back(Ellipsis);
916     }
917     for (const auto i : c10::irange(sparse_dim_intersec)) {
918       const auto s_idx = s_start_dim_intersec + i;
919       intersec_indices.emplace_back(s_indices[s_idx]);
920     }
921     for (auto i = d_start_dim_intersec + sparse_dim_intersec; i < d_dim; ++i) {
922       intersec_indices.emplace_back(Slice());
923     }
924     // we need to expand d in the dimensions it is being indexed into
925     // to avoid out of bound indices
926     const auto d_expanded_shape = std::vector<int64_t>(
927         res_shape.end() - d_dim, res_shape.end());
928     return d.expand(d_expanded_shape).index(intersec_indices);
929   }();
930 
931   // When dims match or sparse is "larger", the result nnz is the same,
932   // so only values get modified.
933   if (s_dim >= d_dim) {
934     return apply_op(d_filtered);
935   }
936 
937   // Otherwise nnz gets larger, and both indices and values need an update.
938   const auto d_batch_shape = d.sizes().slice(0, d_start_dim_intersec);
939   const auto d_batch_len = static_cast<int64_t>(d_batch_shape.size());
940   int64_t batch_count = 1;
941   int64_t max_batch_dim = 0;
942   std::tie(batch_count, max_batch_dim) = [d_batch_shape]() -> std::tuple<int64_t, int64_t> {
943     int64_t batch_count = 1;
944     int64_t max_batch_dim = 0;
945     for (const auto& b : d_batch_shape) {
946       batch_count *= b;
947       max_batch_dim = std::max(b, max_batch_dim);
948     }
949     return std::make_tuple(batch_count, max_batch_dim);
950   }();
951 
952   const auto res_sparse_dim = static_cast<int64_t>(d_batch_shape.size()) + sparse_dim;
953   const auto res_dense_dim = dense_dim;
954   const auto s_nnz = s._nnz();
955   const auto res_nnz = batch_count * s_nnz;
956   auto res_values_shape = s_values.sizes().vec();
957   res_values_shape[0] = res_nnz;
958   const auto res_values = op(d_filtered, s_values).reshape(res_values_shape);
959   const auto res_indices = [&]() -> Tensor {
960     const auto index_buffer = at::arange(max_batch_dim, s_indices.options());
961     auto indices = at::empty({res_sparse_dim, res_nnz}, s_indices.options());
962     // fill in indices corresponding to the "batch" dimensions of d.
963     int64_t n_repeat_interleave = res_nnz;
964     int64_t n_repeat = 1;
965     for (const auto dim : c10::irange(d_batch_len)) {
966       const auto dim_size = d_batch_shape[dim];
967       n_repeat_interleave /= dim_size;
968       // fill in indices corresponding to the "batch" dimension dim.
969       // Equivalent to indices[dim].copy_(repeat_interleave(dim_index, n_repeat_interleave).repeat(n_repeat))
970       const std::initializer_list<int64_t> dim_index_expanded_shape = {n_repeat, dim_size, n_repeat_interleave};
971       const auto dim_index = index_buffer.slice(-1, 0, dim_size);
972       const auto dim_index_expanded = dim_index.unsqueeze(0).unsqueeze_(-1).expand(dim_index_expanded_shape);
973       // NOTE: indices is contiguous, so view is safe
974       indices[dim].view(dim_index_expanded_shape).copy_(dim_index_expanded);
975       n_repeat *= dim_size;
976     }
977     // fill in indices corresponding to s_indices.
978     // Equivalent to indices_sparse.copy(s_indices.repeat({1, n_repeat})
979     n_repeat = res_nnz / s_nnz;
980     auto indices_sparse = indices.narrow(0, d_batch_len, res_sparse_dim - d_batch_len);
981     const std::initializer_list<int64_t> s_indices_expanded_shape = {-1, n_repeat, s_nnz};
982     const auto s_indices_expanded = s_indices.unsqueeze(1).expand(s_indices_expanded_shape);
983     indices_sparse.view(s_indices_expanded_shape).copy_(s_indices_expanded);
984 
985     return indices;
986   }();
987   auto* res_impl = get_sparse_impl(res);
988   res_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
989   res_impl->set_indices_and_values_unsafe(res_indices, res_values);
990   res_impl->set_nnz_and_narrow(res_nnz);
991   // By design of index expansion and that s is coalesced,
992   // the result is also coalesced.
993   return res._coalesced_(true);
994 }
995 
996 Tensor& _mul_dense_sparse_out(const Tensor& d, const Tensor& s, Tensor& res) {
997   return intersection_binary_op_sparse_dense_out(d, s, res, "mul", [](const Tensor& a, const Tensor& b) -> Tensor {
998       return at::mul(a, b);
999   });
1000 }
1001 
1002 Tensor& _mul_sparse_sparse_zero_dim_out(const Tensor& zero_dim, const Tensor& other, Tensor& r) {
1003   const auto is_wrapped_scalar = [](const Tensor& s) -> bool {
1004     return !s.dim() && s.is_coalesced();
1005   };
1006 
1007   const auto extract_vals_from_wrapped_scalar = [](const Tensor& s) -> Tensor {
1008     auto vals = s._values().squeeze(0);
1009     // if squeeze does not kill the dim, it means that
1010     // vals is empty with shape [0]. In such a case we
1011     // return a 0-dim empty tensor to avoid broadcasting
1012     // issues in intersection_binary_op_sparse_dense_out
1013     // when the sparse argument is actually 0-dim.
1014     if (vals.dim()) {
1015       return at::empty({}, vals.options());
1016     }
1017     return vals;
1018   };
1019 
1020   // The code dispatches to mul(dense, sparse), and the goal
1021   // is to delay calling into coalesce when converting one of
1022   // the sparse arguments to dense if possible.
1023   // This is possible when there is a 0-dim coalesced argument.
1024 
1025   // if is_wrapped_scalar(zero_dim)
1026   if (zero_dim.is_coalesced()) {
1027     const auto scalar_val = extract_vals_from_wrapped_scalar(zero_dim);
1028     return _mul_dense_sparse_out(scalar_val, other, r);
1029   }
1030   // Here zero_dim is not a wrapped scalar, so we test other.
1031   if (is_wrapped_scalar(other)) {
1032     const auto scalar_val = extract_vals_from_wrapped_scalar(other);
1033     return _mul_dense_sparse_out(scalar_val, zero_dim, r);
1034   }
1035   // Neither of inputs is a wrapped scalar, but zero_dim
1036   // is at least 0-dim, so we coalesce it to convert to
1037   // a scalar.
1038   const auto scalar_val = extract_vals_from_wrapped_scalar(zero_dim.coalesce());
1039   return _mul_dense_sparse_out(scalar_val, other, r);
1040 }
1041 
1042 DEFINE_DISPATCH(mul_sparse_sparse_out_stub);
1043 
1044 Tensor& _mul_sparse_sparse_out(const Tensor& x, const Tensor& y, Tensor& res) {
1045   mul_sparse_sparse_out_stub(res.device().type(), res, x, y);
1046   return res;
1047 }
1048 
1049 SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, Tensor& r) {
1050   AT_ASSERT(!t_.is_cuda()); // dispatch argument
1051   TORCH_CHECK(!r.is_cuda(), "mul: expected 'out' to be CPU tensor, but got CUDA tensor");
1052   TORCH_CHECK(!src_.is_cuda(), "mul: expected 'other' to be a CPU tensor, but got a CUDA tensor");
1053   // case mul(sparse, dense)
1054   if (!src_.is_sparse()) {
1055     return _mul_dense_sparse_out(src_, t_, r);
1056   }
1057   // case mul(dense, sparse)
1058   if (!t_.is_sparse()) {
1059     return _mul_dense_sparse_out(t_, src_, r);
1060   }
1061 
1062   // case mul(sparse, sparse) with a 0-dim input.
1063   if (!src_.dim()) {
1064     return _mul_sparse_sparse_zero_dim_out(src_, t_, r);
1065   }
1066   if (!t_.dim()) {
1067     return _mul_sparse_sparse_zero_dim_out(t_, src_, r);
1068   }
1069 
1070   const auto is_equal_size_inputs = t_.sizes().equals(src_.sizes());
1071 
1072   // mul(sparse, sparse) with inputs which broadcast only in dense dims
1073   if (!is_equal_size_inputs) {
1074     _mul_sparse_sparse_out(t_, src_, r);
1075     return r;
1076   }
1077 
1078   TORCH_CHECK(is_equal_size_inputs, "mul: expected 'self' and 'other' to have same sizes when both are sparse"
1079       ", but ", t_.sizes(), " != ", src_.sizes());
1080 
1081   // Short circuit when there is zero nnz
1082   // Not strictly necessary, but there are tests checking whether
1083   // resize in mul fails if run on tensors coming from .data/.detach.
1084   if (!t_._nnz() || !src_._nnz()) {
1085     r.resize_as_(t_);
1086     return r.zero_();
1087   }
1088 
1089   // _mul_sparse_sparse_out is faster for large inputs
1090   // and when either of the inputs is uncoalesced.
1091   if (!t_.is_coalesced() || !src_.is_coalesced()) {
1092     _mul_sparse_sparse_out(t_, src_, r);
1093     return r;
1094   }
1095 
1096   // Otherwise _mul_sparse_sparse_out might be slower
1097   // than the brute-force solution below.
1098 
1099   SparseTensor t = t_.coalesce();
1100   SparseTensor src = src_.coalesce();
1101 
1102   // saving those because they can be overwritten when doing in-place operations
1103   int64_t t_nnz = t._nnz(), s_nnz = src._nnz();
1104   int64_t max_nnz = std::min(t_nnz, s_nnz);  // multiply by zero is zero, and can be dropped
1105   int64_t sparse_dim = src.sparse_dim();
1106   Tensor t_indices = t._indices();
1107   Tensor src_indices = src._indices();
1108   Tensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
1109 
1110   int64_t r_i = 0, t_i = 0, s_i = 0;
1111 
1112   auto commonDtype = promoteTypes(t_.scalar_type(), src_.scalar_type());
1113   TORCH_CHECK(canCast(commonDtype, r.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r.scalar_type(), " in mul operation");
1114 
1115   Tensor t_values = t._values().to(commonDtype);
1116   Tensor s_values = src._values().to(commonDtype);
1117 
1118   Tensor r_buffer = new_values_with_size_of(t_values, max_nnz).zero_();
1119 
1120   // NB: relies on nnz test above
1121   auto t_indices_accessor = t_indices.accessor<int64_t, 2>();
1122   auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
1123   auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
1124 
1125   // Check if we can find matching indices, and if so, write an
1126   // entry to the result indices vector.  Returns true if matching
1127   // indices were found.
1128   auto index_preamble = [&]() {
1129     for (auto d: c10::irange(sparse_dim)) {
1130       if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
1131         t_i++;
1132         return false;
1133       }
1134       if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
1135         s_i++;
1136         return false;
1137       }
1138     }
1139     for (auto d: c10::irange(sparse_dim)) {
1140       r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
1141     }
1142     return true;
1143   };
1144 
1145   if (t_values.dim() > 1) {
1146     while (t_i < t_nnz && s_i < s_nnz) {
1147       if (!index_preamble()) continue;
1148       r_buffer.select(0, r_i).addcmul_(t_values.select(0, t_i), s_values.select(0, s_i));
1149       r_i++;
1150       t_i++;
1151       s_i++;
1152     }
1153   } else {
1154     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
1155         at::ScalarType::ComplexHalf, at::ScalarType::BFloat16, at::ScalarType::Half,
1156         commonDtype, "mul_out_sparse", [&] {
1157           auto r_accessor = r_buffer.accessor<scalar_t, 1>();
1158           auto t_accessor = t_values.accessor<scalar_t, 1>();
1159           auto s_accessor = s_values.accessor<scalar_t, 1>();
1160 
1161           while (t_i < t_nnz && s_i < s_nnz) {
1162             if (!index_preamble()) continue;
1163             r_accessor[r_i] = t_accessor[t_i] * s_accessor[s_i];
1164             r_i++;
1165             t_i++;
1166             s_i++;
1167           }
1168         }
1169     );
1170   }
1171 
1172   r.resize_as_(src);
1173   Tensor r_values = r_buffer.to(r.scalar_type());
1174   get_sparse_impl(r)->set_indices_and_values_unsafe(r_indices, r_values);
1175   get_sparse_impl(r)->set_nnz_and_narrow(r_i);
1176   return r._coalesced_(true);
1177 }
1178 
1179 // --------------------------------------------------------------------
1180 // addmm(D1, S, D2, beta, alpha) -> D  [broadcasts]
1181 //
1182 // D = beta * D1 + alpha * mm(S, D2)
1183 // --------------------------------------------------------------------
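// Usage sketch (comment-only, assuming the public ATen API; shapes are illustrative):
//   auto S  = at::rand({4, 5}).to_sparse();   // sparse mat1
//   auto D2 = at::rand({5, 3});               // dense mat2
//   auto D1 = at::rand({4, 3});
//   auto D  = at::addmm(D1, S, D2, /*beta=*/1.0, /*alpha=*/1.0);  // == beta * D1 + alpha * mm(S, D2)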
1184 
1185 template <typename scalar_t>
1186 void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& indices, const Tensor& values, const Tensor& dense) {
1187 
1188   // r_ = alpha * sparse * dense
1189   scalar_t cast_alpha = alpha.to<scalar_t>();
1190   scalar_t cast_beta = beta.to<scalar_t>();
1191 
1192   if (cast_beta == static_cast<scalar_t>(0)) {
1193     r.zero_();
1194   } else if (cast_beta == static_cast<scalar_t>(1)) {
1195     if (!is_same_tensor(r, t)) {
1196       r.copy_(t);
1197     }
1198   } else {
1199     at::mul_out(r, t, scalar_to_tensor(beta));
1200   }
1201 
1202   auto indices_accessor = indices.accessor<int64_t, 2>();
1203 
1204   auto values_accessor = values.accessor<scalar_t, 1>();
1205   scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
1206   scalar_t* r_ptr = r.data_ptr<scalar_t>();
1207 
1208   int64_t dense_stride0 = dense.stride(0);
1209   int64_t dense_stride1 = dense.stride(1);
1210   int64_t r_stride0 = r.stride(0);
1211   int64_t r_stride1 = r.stride(1);
1212   for (auto i: c10::irange(nnz)) {
1213     scalar_t val = values_accessor[i];
1214     int64_t row = indices_accessor[0][i];
1215     int64_t col = indices_accessor[1][i];
1216     if (col >= 0 && col < dim_j && row >= 0 && row < dim_i) {
1217       // AXPY call is no-op over an empty vector
1218       if (dim_k == 0) {
1219         continue;
1220       }
1221       at::native::cpublas::axpy<scalar_t>(dim_k,
1222             cast_alpha * val,
1223             dense_ptr + col * dense_stride0, dense_stride1,
1224             r_ptr + row * r_stride0, r_stride1);
1225     } else {
1226       if (col < 0 || col >= dim_j) {
1227         AT_ERROR("addmm: index out of column bound: ", col, " not between 1 and ", dim_j);
1228       } else {
1229         AT_ERROR("addmm: index out of row bound: ", row, " not between 1 and ", dim_i);
1230       }
1231     }
1232   }
1233 };
1234 
1235 static Tensor& s_addmm_out_sparse_dense_cpu(
1236     Tensor& r,
1237     const Tensor& t,
1238     const SparseTensor& sparse_,
1239     const Tensor& dense,
1240     const Scalar& beta,
1241     const Scalar& alpha) {
1242   // TODO: This error message seems awfully opaque
1243   TORCH_CHECK(
1244       t.is_cpu(),
1245       "Expected all tensors to be on the same device. addmm expected 't' to be CPU tensor, but got tensor on ",
1246       t.device());
1247   TORCH_CHECK(
1248       r.is_cpu(),
1249       "Expected all tensors to be on the same device. addmm: expected 'out' to be CPU tensor, but got tensor on ",
1250       r.device());
1251   TORCH_CHECK(
1252       sparse_.is_cpu(),
1253       "Expected all tensors to be on the same device. addmm: expected 'mat1' to be a CPU tensor, but got tensor on ",
1254       sparse_.device());
1255   TORCH_CHECK(
1256       dense.is_cpu(),
1257       "Expected all tensors to be on the same device. addmm: expected 'mat2' to be a CPU tensor, but got tensor on ",
1258       dense.device());
1259 
1260   TORCH_CHECK(
1261       r.layout() == kStrided,
1262       "addmm_sparse_dense: expected strided result tensor, got tensor with layout ",
1263       r.layout());
1264   TORCH_CHECK(
1265       t.layout() == kStrided,
1266       "addmm_sparse_dense: expected 't' to have strided layout, got tensor with layout ",
1267       t.layout());
1268   TORCH_CHECK(
1269       sparse_.layout() == kSparse && dense.layout() == kStrided,
1270       "addmm_sparse_dense: expected either 'mat1' to have sparse layout and 'mat2' to have strided layout, got 'mat1' with layout ",
1271       sparse_.layout(),
1272       " and 'mat2' with layout ",
1273       dense.layout());
1274 
1275   TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: matrices expected, got ", sparse_.sparse_dim(), "D tensor");
1276   TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: scalar values expected, got ", sparse_.dense_dim(), "D values");
1277   TORCH_CHECK(dense.dim() == 2, "addmm: matrices expected, got ", dense.dim(), "D tensor");
1278 
1279   // ixj * jxk = ixk
1280   int64_t dim_i = sparse_.size(0);
1281   int64_t dim_j = sparse_.size(1);
1282   int64_t dim_k = dense.size(1);
1283 
1284   TORCH_CHECK(dense.size(0) == dim_j,
1285       "addmm: Argument #3 (dense): Expected dim 0 size ", dim_j, ", got ", dense.size(0));
1286   TORCH_CHECK(t.size(0) == dim_i,
1287       "addmm: Argument #1 (t): Expected dim 0 size ", dim_i, ", got ", t.size(0));
1288   TORCH_CHECK(t.size(1) == dim_k,
1289       "addmm: Argument #1 (t): Expected dim 1 size ", dim_k, ", got ", t.size(1));
1290 
1291   r.resize_({dim_i, dim_k});
1292 
1293   int64_t nnz        = sparse_._nnz();
1294 
1295   if (nnz == 0) {
1296     at::mul_out(r, t, at::scalar_tensor(beta, r.options()));
1297     return r;
1298   }
1299 
1300   Tensor indices = sparse_._indices();
1301   Tensor values      = sparse_._values();
1302 
1303   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
1304       values.scalar_type(), "addmm_sparse_dense", [&] {
1305         s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense);
1306       }
1307   );
1308 
1309   return r;
1310 }
1311 
1312 Tensor& addmm_out_sparse_dense_cpu(
1313     const Tensor& self,
1314     const SparseTensor& mat1,
1315     const Tensor& mat2,
1316     const Scalar& beta,
1317     const Scalar& alpha,
1318     Tensor& result) {
1319   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
1320   return s_addmm_out_sparse_dense_cpu(result, *b_self, mat1, mat2, beta, alpha);
1321 }
1322 
1323 static Tensor s_addmm_sparse_dense_cpu(
1324     const Tensor& t,
1325     const SparseTensor& sparse,
1326     const Tensor& dense,
1327     const Scalar& beta,
1328     const Scalar& alpha
1329 ) {
1330   Tensor r = at::empty({0}, t.options());
1331   s_addmm_out_sparse_dense_cpu(r, t, sparse, dense, beta, alpha);
1332   return r;
1333 }
1334 
1335 Tensor addmm_sparse_dense_cpu(
1336     const Tensor& self,
1337     const SparseTensor& mat1,
1338     const Tensor& mat2,
1339     const Scalar& beta,
1340     const Scalar& alpha
1341 ) {
1342   c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
1343   return s_addmm_sparse_dense_cpu(*b_self, mat1, mat2, beta, alpha);
1344 }
1345 
1346 Tensor& s_addmm_sparse_dense_cpu_(
1347     Tensor& t,
1348     const SparseTensor& sparse,
1349     const Tensor& dense,
1350     const Scalar& beta,
1351     const Scalar& alpha
1352 ) {
1353   return s_addmm_out_sparse_dense_cpu(t, t, sparse, dense, beta, alpha);
1354 }
1355 
1356 // NB: Purposely no broadcasting version of addmm inplace
1357 
1358 Tensor _sparse_addmm(
1359   const Tensor& t,
1360   const SparseTensor& sparse,
1361   const Tensor& dense,
1362   const Scalar& beta,
1363   const Scalar& alpha
1364 ) {
1365   // _sparse_addmm forward is functionally equivalent to addmm; it's
1366   // just the backward that is different.  This technically does an
1367 // unnecessary redispatch; I was too lazy to make it not do that.
1368   return at::addmm(t, sparse, dense, beta, alpha);
1369 }
1370 
1371 Tensor _sparse_mm(
1372   const Tensor& mat1,
1373   const Tensor& mat2
1374 ) {
1375   if (mat1.is_sparse() && mat2.is_sparse()) {
1376     return at::_sparse_sparse_matmul(mat1, mat2);
1377   }
1378   if (mat1.is_sparse() || at::sparse_csr::is_sparse_compressed(mat1)) {
1379     Tensor t = at::zeros({mat1.size(-2), mat2.size(-1)}, mat2.options());
1380     return at::_sparse_addmm(t, mat1, mat2, 0, 1);
1381   }
1382   Tensor t = at::zeros({mat1.size(-2), mat2.size(-1)}, mat1.options());
1383   return at::_sparse_addmm(t.transpose(-2, -1), mat2.transpose(-2, -1), mat1.transpose(-2, -1), 0, 1).transpose(-2, -1);
1384 }
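// A rough sketch of the dispatch above (hypothetical names): with 2D sparse
// COO tensors S, S2 and a dense tensor D of compatible shapes,
//
//   at::_sparse_mm(S, S2);  // sparse @ sparse -> at::_sparse_sparse_matmul
//   at::_sparse_mm(S, D);   // sparse @ dense  -> _sparse_addmm with beta == 0
//   at::_sparse_mm(D, S);   // dense @ sparse  -> computed as (S^T @ D^T)^T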
1385 
1386 // NB: Despite its suggestive name, this actually only exists so that
1387 // we can redispatch to addmm_out; this is NOT an implementation of
1388 // the sparse masking version of mm
1389 SparseTensor& _sparse_mm_out(const SparseTensor& sparse,
1390   const Tensor& dense,
1391   SparseTensor& result) {
1392   Tensor t = at::zeros({}, dense.options());
1393   return at::addmm_out(result, t, sparse, dense, 0, 1);  // redispatch!
1394 }
1395 
1396 Tensor _sparse_mm(const Tensor& mat1, const Tensor& mat2, const c10::string_view reduce) {
1397   // result: out, arg_out
1398   auto result = at::_sparse_mm_reduce_impl(mat1, mat2, reduce);
1399   return std::get<0>(result);
1400 }
1401 
1402 // --------------------------------------------------------------------
1403 // hspmm(SparseTensor mat1, Tensor mat2)
1404 // --------------------------------------------------------------------
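//
// A small behavioral sketch (hypothetical shapes): for a coalesced 2D sparse
// tensor S of size (m, k) whose nonzeros fall only in rows {0, 3} and a dense
// D of size (k, n), hspmm(S, D) returns a hybrid sparse tensor of size (m, n)
// with sparse_dim == 1 and dense_dim == 1: indices == [[0, 3]] and a values
// tensor of shape (2, n) holding those two rows of S @ D.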
1405 
1406 SparseTensor& hspmm_out_sparse_cpu(const SparseTensor& sparse_, const Tensor& dense, SparseTensor& r) {
1407   // TODO: Make this a real argument
1408   Scalar alpha = 1;
1409 
1410   AT_ASSERT(!sparse_.is_cuda()); // dispatch argument
1411   TORCH_CHECK(!r.is_cuda(), "hspmm: expected 'out' to be CPU tensor, but got CUDA tensor");
1412   TORCH_CHECK(!dense.is_cuda(), "hspmm: expected 'other' to be a CPU tensor, but got a CUDA tensor");
1413 
1414   TORCH_CHECK(sparse_.sparse_dim() == 2,
1415       "hspmm: Argument #2: matrices expected, got ", sparse_.sparse_dim(), "D tensor");
1416   TORCH_CHECK(sparse_.dense_dim() == 0,
1417       "hspmm: Argument #2: scalar values expected, got ", sparse_.dense_dim(), "D values");
1418   TORCH_CHECK(dense.dim() == 2,
1419       "hspmm: Argument #3: matrices expected, got ", dense.dim(), "D tensor");
1420 
1421   int64_t m = sparse_.size(0);
1422   int64_t k = sparse_.size(1);
1423   int64_t n = dense.size(1);
1424 
1425   TORCH_CHECK(dense.size(0) == k,
1426       "hspmm: Argument #3: Expected dim 0 size ", k, ", got ", dense.size(0));
1427 
1428   get_sparse_impl(r)->raw_resize_(1, 1, {m, n});
1429 
1430   SparseTensor sparse = sparse_.coalesce();
1431 
1432   int64_t nnz = sparse._nnz();
1433 
1434   if (nnz == 0) {
1435     r.zero_();
1436     return r;
1437   }
1438 
1439   Tensor indices = at::empty({1, nnz}, at::initialTensorOptions().dtype(kLong));
1440 
1441   // Initialize the sparse matrix that will be used with s_addmm_out_sparse_dense_cpu
1442   // to send rows from the dense matrix to rows of the output's value tensor
1443   SparseTensor newSparse = sparse.clone();
1444   Tensor spIndices = newSparse._indices();
1445   Tensor valueIndices = spIndices.select(0, 0);
1446 
1447   // Compute output indices
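  // For example, if the coalesced row indices are [0, 0, 3, 3, 3, 7], the
  // loop below produces indices == [[0, 3, 7]] (outNnz == 3) and rewrites
  // valueIndices to [0, 0, 1, 1, 1, 2], so each nonzero points at its row in
  // the compacted output values tensor.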
1448   auto valueIndices_accessor = valueIndices.accessor<int64_t, 1>();
1449   auto indices_accessor = indices.accessor<int64_t, 2>();
1450 
1451   int64_t i = -1, prevIdx = -1;
1452   for (const auto j : c10::irange(nnz)) {
1453     int64_t currIdx = valueIndices_accessor[j];
1454     if (currIdx != prevIdx) {
1455       indices_accessor[0][++i] = currIdx;
1456       prevIdx = currIdx;
1457     }
1458     valueIndices_accessor[j] = i;
1459   }
1460   int64_t outNnz = i + 1;
1461   indices.resize_({1, outNnz});
1462   Tensor values = at::empty({outNnz, n}, dense.options());
1463 
1464   std::vector<int64_t> new_size = get_sparse_impl(newSparse)->sizes().vec();
1465   new_size[0] = outNnz;
1466   get_sparse_impl(newSparse)->raw_resize_(get_sparse_impl(newSparse)->sparse_dim(), get_sparse_impl(newSparse)->dense_dim(), new_size);
1467 
1468   // Compute output values tensor with sparse * dense multiplication
1469   s_addmm_out_sparse_dense_cpu(values, values, newSparse, dense, 0, alpha);
1470   get_sparse_impl(r)->set_indices_and_values_unsafe(indices, values);
1471 
1472   return r;
1473 }
1474 
1475 SparseTensor hspmm_sparse_cpu(const SparseTensor& sparse, const Tensor& dense) {
1476   SparseTensor r = at::empty({0}, sparse.options());
1477   hspmm_out_sparse_cpu(sparse, dense, r);
1478   return r;
1479 }
1480 
1481 // --------------------------------------------------------------------
1482 // sspaddmm(S1, S2, D, beta, alpha) -> S
1483 //
1484 // S = beta * S1 + alpha * mm(S2, D)
1485 // --------------------------------------------------------------------
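//
// A minimal usage sketch (hypothetical names): with 2D sparse COO tensors
// S1, S2 and a dense D of compatible shapes,
//
//   Tensor out = at::sspaddmm(S1, S2, D, /*beta=*/1, /*alpha=*/1);
//   // out is sparse and, up to floating-point rounding, equals
//   // (S1.to_dense() + S2.to_dense().mm(D)).to_sparse()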
1486 
1487 SparseTensor& _sspaddmm_out_cpu(
1488     const SparseTensor& t,
1489     const SparseTensor& sparse_,
1490     const Tensor& dense,
1491     const Scalar& beta,
1492     const Scalar& alpha,
1493     SparseTensor& r) {
1494   AT_ASSERT(!t.is_cuda()); // dispatch argument
1495   TORCH_CHECK(!r.is_cuda(), "sspaddmm: expected 'out' to be CPU tensor, but got CUDA tensor");
1496   TORCH_CHECK(!sparse_.is_cuda(), "sspaddmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
1497   TORCH_CHECK(!dense.is_cuda(), "sspaddmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
1498 
1499   TORCH_CHECK(sparse_.sparse_dim() == 2,
1500       "sspaddmm: Argument #2: matrices expected, got ", sparse_.sparse_dim(), "D tensor");
1501   TORCH_CHECK(sparse_.dense_dim() == 0,
1502       "sspaddmm: Argument #2: scalar values expected, got ", sparse_.dense_dim(), "D values");
1503   TORCH_CHECK(dense.dim() == 2,
1504       "sspaddmm: Argument #2: matrices expected, got ", dense.dim(), "D tensor");
1505 
1506   SparseTensor sparse = sparse_.coalesce();
1507 
1508   // ixj * jxk = ixk
1509   int64_t dim_i = sparse.size(0);
1510   int64_t dim_j = sparse.size(1);
1511   int64_t dim_k = dense.size(1);
1512 
1513   // NB: This has to occur before the checks, because r may alias t.
1514   // See test_saddmm
1515   get_sparse_impl(r)->raw_resize_(2, 0, {dim_i, dim_k});
1516 
1517   TORCH_CHECK(dense.size(0) == dim_j,
1518       "sspaddmm: Argument #3: Expected dim 0 size ", dim_j, ", got ", dense.size(0));
1519   TORCH_CHECK(t.size(0) == dim_i,
1520       "sspaddmm: Argument #1: Expected dim 0 size ", dim_i, ", got ", t.size(0));
1521   TORCH_CHECK(t.size(1) == dim_k,
1522       "sspaddmm: Argument #1: Expected dim 1 size ", dim_k, ", got ", t.size(1));
1523 
1524   int64_t nnz        = sparse._nnz();
1525   // We have to make indices contiguous as we use indices.data_ptr in coo_to_csr, which assumes row-contiguous storage
1526   Tensor indices = sparse._indices().contiguous();
1527   Tensor values      = sparse._values();
1528 
1529   Tensor csr = coo_to_csr(indices.data_ptr<int64_t>(), dim_i, nnz);
1530 
1531   int64_t t_nnz = t._nnz();
1532   int64_t r_nnz = nnz * dim_k + t_nnz;
1533   Tensor newi = at::empty({2, r_nnz}, kLong);
1534   Tensor newv = at::zeros(
1535       {r_nnz},
1536       optTypeMetaToScalarType(values.options().dtype_opt()),
1537       values.options().layout_opt(),
1538       values.options().device_opt(),
1539       values.options().pinned_memory_opt());
1540 
1541   if (t_nnz != 0) {
1542     Tensor narrowi = newi.narrow(1, 0, t_nnz);
1543     Tensor narrowv = newv.narrow(0, 0, t_nnz);
1544 
1545     narrowi.copy_(t._indices());
1546     narrowv.copy_(t._values());
1547     newv.mul_(beta);
1548   }
1549 
1550   // sparse = sparse * dense
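  // For every row h of `sparse` that has at least one nonzero, the loop below
  // appends one dense output row of width dim_k: the accumulated values go to
  // newv[p : p + dim_k] and the matching (h, 0..dim_k-1) index pairs go to
  // newi, after which p advances by dim_k.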
1551   int64_t p = t_nnz;
1552 
1553   auto csr_accessor = csr.accessor<int64_t, 1>();
1554   auto indices_accessor = indices.accessor<int64_t, 2>();
1555   auto newi_accessor = newi.accessor<int64_t, 2>();
1556 
1557   int64_t dense_stride0 = dense.stride(0);
1558   int64_t dense_stride1 = dense.stride(1);
1559   int64_t newv_stride0 = newv.stride(0);
1560 
1561   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
1562       values.scalar_type(), "sspmm", [&] {
1563         auto values_accessor = values.accessor<scalar_t, 1>();
1564         scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
1565         scalar_t* newv_ptr = newv.data_ptr<scalar_t>();
1566         scalar_t cast_alpha = alpha.to<scalar_t>();
1567 
1568         for (const auto h : c10::irange(dim_i)) {
1569           int64_t i_start = csr_accessor[h];
1570           int64_t i_end = csr_accessor[h+1];
1571           for (const auto i : c10::irange(i_start, i_end)) {
1572             scalar_t val = values_accessor[i];
1573             int64_t col = indices_accessor[1][i];
1574             if (col >= 0 && col < dim_j) {
1575               at::native::cpublas::axpy<scalar_t>(dim_k,
1576                   cast_alpha * val,
1577                   dense_ptr + col * dense_stride0, dense_stride1,
1578                   newv_ptr + p * newv_stride0, 1);
1579             } else {
1580               AT_ERROR("index out of bounds. sspmm: ", col, " not between 0 and ", dim_j - 1);
1581             }
1582           }
1583           // Fill up the indices with the right values
1584           if (i_start != i_end) {
1585             for (const auto i : c10::irange(dim_k)) {
1586               newi_accessor[0][p+i] = h;
1587               newi_accessor[1][p+i] = i;
1588             }
1589             p += dim_k;
1590           }
1591         }
1592       }
1593   );
1594 
1595   // to avoid a clone
1596   get_sparse_impl(r)->set_indices_and_values_unsafe(newi, newv);
1597   get_sparse_impl(r)->set_nnz_and_narrow(p);
1598 
1599   return r;
1600 }
1601 
1602 // sparse, sparse, sparse, dense, real, real -> sparse
1603 Tensor& _sspaddmm_out_only_sparse(const Tensor& self,
1604     const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
1605   AT_ERROR("tensor.sspaddmm(...) can only be called on sparse tensors");
1606 }
1607 
1608 // sparse, dense -> sparse
1609 Tensor smm(const Tensor& self, const Tensor& mat2) {
1610   auto result = at::empty({0}, self.options());
1611   at::sspaddmm_out(result, result, self, mat2, 0.0, 1.0);
1612   return result;
1613 }
1614 
1615 // sparse, sparse, dense, real, real -> sparse
1616 Tensor sspaddmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2,
1617     const Scalar& beta, const Scalar& alpha) {
1618   auto result = at::empty({0}, self.options());
1619   at::sspaddmm_out(result, self, mat1, mat2, beta, alpha);
1620   return result;
1621 }
1622 
1623 // --------------------------------------------------------------------
1624 // sparse.sum()
1625 //
1626 // This implementation calls coalesce() to do the sum reduction on
1627 // sparse dims. Ideally in the future there should be a unified reduction function
1628 // for ops like sum, max, and min.
1629 // --------------------------------------------------------------------
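//
// A small behavioral sketch (hypothetical values): for a sparse COO tensor S
// of size (3, 4) with indices [[0, 0, 2], [1, 3, 1]] and values [1., 2., 3.],
//
//   at::_sparse_sum(S);       // dense scalar 6.
//   at::_sparse_sum(S, {0});  // sparse tensor of size (4,) with indices
//                             // [[1, 3]] and values [4., 2.]; the two entries
//                             // in column 1 are combined by the coalesce()
//                             // call below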
1630 Tensor _sparse_sum(const SparseTensor& input) {
1631   return input.coalesce().values().sum();
1632 }
1633 
1634 Tensor _sparse_sum(const SparseTensor& input, ScalarType dtype) {
1635   // We don't have to convert to the correct dtype first;
1636   // we just need to set up the accumulator correctly
1637   return input.coalesce().values().sum(dtype);
1638 }
1639 
1640 Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum, ScalarType dtype) {
1641   return at::_sparse_sum(input.to(dtype), dims_to_sum);
1642 }
1643 
1644 Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
1645   const int64_t input_dim = input.dim();
1646   auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
1647   auto dims_to_sum_v = dims_to_sum.vec();
1648   maybe_wrap_dims(dims_to_sum_v, input_dim);
1649 
1650   Tensor indices = input._indices();
1651   Tensor values = input._values();
1652   IntArrayRef sizes = input.sizes();
1653   const int64_t sparse_dim = input.sparse_dim();
1654 
1655   auto dims_to_keep_v = std::vector<int64_t>();
1656   auto dense_dims_to_sum_v = std::vector<int64_t>();
1657   for (const auto d : c10::irange(input_dim)) {
1658     if (dims_to_sum_b[d]) {
1659       if (d >= sparse_dim) dense_dims_to_sum_v.emplace_back(d + 1 - sparse_dim);
1660     }
1661     else {
1662       dims_to_keep_v.emplace_back(d);
1663     }
1664   }
1665   const int64_t sparse_dims_to_sum_size = dims_to_sum_v.size() - dense_dims_to_sum_v.size();
1666   const bool sum_all_sparse_dim = (sparse_dim == sparse_dims_to_sum_size);
1667   const bool sum_dense_dim = (!dense_dims_to_sum_v.empty());
1668 
1669   // new values
1670   Tensor new_values;
1671   if (sum_dense_dim) {
1672     new_values = values.sum(dense_dims_to_sum_v);
1673   }
1674   else {
1675     new_values = values.clone(at::MemoryFormat::Contiguous);
1676   }
1677 
1678   if (sum_all_sparse_dim) {
1679     // return a dense tensor if sum over all sparse dims
1680     new_values = new_values.sum(0);
1681     return new_values;
1682   }
1683   else { // !sum_all_sparse_dim
1684     // new indices
1685     Tensor new_indices;
1686     if (sparse_dims_to_sum_size == 0) {
1687       new_indices = indices.clone(at::MemoryFormat::Contiguous);
1688     }
1689     else {
1690       new_indices = at::empty({sparse_dim - sparse_dims_to_sum_size, input._nnz()}, indices.options());
1691       for (auto i: c10::irange(dims_to_keep_v.size())) {
1692         int64_t d = dims_to_keep_v[i];
1693         if (d < sparse_dim) new_indices[i].copy_(indices[d]);
1694         else break;
1695       }
1696     }
1697 
1698     // new size
1699     int64_t new_sparse_dim = new_indices.size(0);
1700     int64_t new_dense_dim = new_values.dim() - 1; // exclude nnz dim
1701     std::vector<int64_t> new_sizes;
1702     new_sizes.reserve(dims_to_keep_v.size());
1703     for (auto d : dims_to_keep_v) new_sizes.emplace_back(sizes[d]);
1704     if (sum_all_sparse_dim) new_sizes.emplace(new_sizes.begin(), 1);
1705 
1706     // use coalesce() to do sum reduction
1707     bool is_coalesced = false;  // TODO: can we use input.is_coalesced()?
1708     SparseTensor new_sparse = at::_sparse_coo_tensor_with_dims_and_tensors(new_sparse_dim, new_dense_dim, new_sizes, new_indices, new_values, input.options(), is_coalesced);
1709     new_sparse = new_sparse.coalesce();
1710     return new_sparse;
1711   }
1712 
1713 }
1714 // --------------------------------------------------------------------
1715 // NOTE [ sparse.sum() backward ]
1716 //
1717 // When sum over sparse_dim, backward scatters gradients from grad tensor to input tensor.
1718 // Grad and input need to align indices over sparse_dim that are not summed (given
1719 // input.spares_dim >= grad.sparse_dim). Implementation here compares each pair of
1720 // indices between grad and input. When a matching indices pair (input_i, grad_i) is found,
1721 // copy grad.values[grad_i] -> input_grad.values[input_i]. E.g.,
1722 //
1723 //  input.sparse_dim = [5, 5]
1724 //  input.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
1725 //                   [1, 4, 4, 0, 1, 3, 2, 4]]
1726 //  input.values =   [0, 1, 2, 3, 4, 5, 6, 7]
1727 //  ...
1728 //  sparse.sum(input, [0])
1729 //  backward(...)
1730 //  ...
1731 //  grad.indices = [[0, 1, 2, 3]]
1732 //  grad.values =   [1, 2, 0, 4]
1733 //
1734 // # after indices matching
1735 //         input         grad
1736 //        [[0, 1],   ->  [1]
1737 //         [0, 4],   ->  [ ]
1738 //         [1, 4],   ->  [ ]
1739 //         [2, 0],   ->  [0]
1740 //         [2, 1],   ->  [1]
1741 //         [3, 3],   ->  [3]
1742 //         [4, 2],   ->  [2]
1743 //         [4, 4]])  ->  [ ]
1744 //
1745 // input_grad.indices = [[0, 0, 1, 2, 2, 3, 4, 4],
1746 //                       [1, 4, 4, 0, 1, 3, 2, 4]]
1747 // input_grad.values =   [2, 0, 0, 1, 2, 4, 0, 0]
1748 //
1749 // Note that we allow input to be uncoalesced in the forward,
1750 // but we have to coalesce input in the backward, because grad-of-input
1751 // takes the same indices as input. If input were not coalesced, then
1752 // coalescing grad-of-input may add up grad values for duplicate indices
1753 // and hence generate a wrong grad-of-input.
1754 //
1755 // Other edge cases:
1756 // - assign zero values to input gradients if no matching indices are found in grad
1757 // - grad.values might have zeros
1758 // --------------------------------------------------------------------
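//
// A sketch of the matching step implemented below, reusing the example above
// and assuming flatten_indices_by_dims composes the listed dims row-major:
// flattening input.indices over the kept sparse dim [1] gives
// [1, 4, 4, 0, 1, 3, 2, 4] and flattening grad.indices gives [0, 1, 2, 3];
// for each input position, a binary search over the sorted grad indices
// either finds the grad row to copy from or leaves that row of
// input_grad.values at zero.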
1759 Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_, IntArrayRef dims_to_sum) {
1760   TORCH_CHECK(!grad_.is_cuda(), "_sparse_sum_backward_cpu: expected 'grad_' to be CPU tensor, but got CUDA tensor");
1761   TORCH_CHECK(!input_.is_cuda(), "_sparse_sum_backward_cpu: expected 'input_' to be CPU tensor, but got CUDA tensor");
1762 
1763   // Short circuit if grad is either zero or empty.
1764   if (((grad_.is_sparse() || at::sparse_csr::is_sparse_compressed(grad_)) && !grad_._nnz()) || !grad_.numel()) {
1765     return at::zeros_like(input_);
1766   }
1767 
1768   auto input = input_.coalesce();
1769   const int64_t input_dim = input.dim();
1770   auto dims_to_sum_b = dim_list_to_bitset(dims_to_sum, input_dim);
1771   auto dims_to_sum_v = dims_to_sum.vec();
1772   maybe_wrap_dims(dims_to_sum_v, input_dim);
1773 
1774   Tensor input_indices = input._indices();
1775   Tensor input_values = input._values();
1776   IntArrayRef input_sizes = input.sizes();
1777   const int64_t input_sparse_dim = input.sparse_dim();
1778   const int64_t input_dense_dim = input.dense_dim();
1779   const int64_t input_nnz = input._nnz();
1780 
1781   int64_t sparse_dims_to_sum_size = 0;
1782   auto sparse_dims_to_keep_v = std::vector<int64_t>();
1783   auto dense_dims_to_sum_v = std::vector<int64_t>();
1784   for (auto d: c10::irange(input_dim)) {
1785     if (dims_to_sum_b[d]) {
1786       if (d < input_sparse_dim) sparse_dims_to_sum_size ++;
1787       else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim);
1788     }
1789     else {
1790       if (d < input_sparse_dim) sparse_dims_to_keep_v.emplace_back(d);
1791     }
1792   }
1793 
1794   const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
1795   const bool sum_dense_dim = (!dense_dims_to_sum_v.empty());
1796   const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);
1797 
1798   if (sum_all_sparse_dim) {
1799     TORCH_CHECK(!grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be dense since all sparse dims are summed");
1800     auto grad_input_values = grad_;
1801     auto expand_size = input_values.sizes().vec();
1802     if (sum_dense_dim) {
1803       auto dense_expand_size = std::vector<int64_t>(expand_size);
1804       dense_expand_size.erase(dense_expand_size.begin());
1805       AT_ASSERT(dense_expand_size.size() == static_cast<size_t>(input_values.dim() - 1));
1806       for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1);  // -1 since grad has no nnz dim
1807       grad_input_values = grad_input_values.expand(dense_expand_size);
1808     }
1809     grad_input_values = grad_input_values.expand(expand_size).clone(at::MemoryFormat::Contiguous);
1810     bool grad_is_coalesced = input.is_coalesced();
1811     return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, input.options().dtype(grad_.dtype()), grad_is_coalesced); // convert to grad dtype
1812   }
1813   else {
1814     TORCH_CHECK(grad_.is_sparse(), "_sparse_sum_backward_cpu: expected grad_ Tensor to be sparse, but got dense");
1815     auto grad = grad_.coalesce();
1816     Tensor grad_indices = grad._indices();
1817     Tensor grad_values = grad._values();
1818     const int64_t grad_sparse_dim = grad.sparse_dim();
1819     const int64_t grad_nnz = grad._nnz();
1820 
1821     Tensor grad_values_expand = grad_values;
1822     if (sum_dense_dim) {
1823       auto expand_size = input_values.sizes().vec();
1824       if (sum_sparse_dim) expand_size[0] = grad_values.size(0);
1825       for (auto d : dense_dims_to_sum_v) grad_values_expand = grad_values_expand.unsqueeze(d);
1826       grad_values_expand = grad_values_expand.expand(expand_size).clone(at::MemoryFormat::Contiguous);
1827     }
1828 
1829     Tensor grad_input_values;
1830     if (sum_sparse_dim) {
1831       // see NOTE [ sparse.sum() backward ]
1832       grad_input_values = at::zeros_like(input_values, grad_values.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
1833 
1834       // get flatten indices for grad and input
1835       auto grad_sparse_dim_to_keep_v = std::vector<int64_t>(grad_sparse_dim);
1836       std::iota(grad_sparse_dim_to_keep_v.begin(), grad_sparse_dim_to_keep_v.end(), 0);
1837 
1838       auto grad_indices_1D = flatten_indices_by_dims(grad_indices, grad.sizes(), grad_sparse_dim_to_keep_v); // flatten indices on all sparse_dim of grad, output indices is coalesced and sorted
1839       auto grad_indices_1D_accessor = grad_indices_1D.accessor<int64_t, 1>();
1840       auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v);
1841       auto input_indices_1D_accessor = input_indices_1D.accessor<int64_t, 1>();
1842 
1843       const auto copy_iter = TensorIteratorConfig()
1844         .add_output(grad_input_values)
1845         .add_input(grad_values_expand)
1846         .resize_outputs(false)
1847         .declare_static_shape(grad_values_expand.sizes(), /*squash_dims=*/0)
1848         .build();
1849       const auto device_type = kCPU;
1850 
1851       const auto gIv_data = reinterpret_cast<char*>(grad_input_values.data_ptr());
1852       const auto gOv_data = reinterpret_cast<char*>(grad_values_expand.data_ptr());
1853       const auto gIv_stride = (grad_input_values.strides()[0] *
1854                                grad_input_values.element_size());
1855       const auto gOv_stride = (grad_values_expand.strides()[0] *
1856                                grad_values_expand.element_size());
1857 
1858       // binary search to find matching indices
1859       at::parallel_for(0, input_nnz, 0, [&](int64_t start, int64_t end) {
1860         TensorIterator copy_iter_local(copy_iter);
1861 
1862         for (auto i: c10::irange(start, end)) {
1863           int64_t input_idx = input_indices_1D_accessor[i];
1864           int64_t l = 0, r = grad_nnz - 1;
1865           while (l <= r) {
1866             int64_t m = l + (r - l) / 2;
1867             if (grad_indices_1D_accessor[m] == input_idx) {
1868               // grad_input_values[i].copy_(grad_values_expand[m])
1869               copy_iter_local.unsafe_replace_operand(0, gIv_data + i * gIv_stride);
1870               copy_iter_local.unsafe_replace_operand(1, gOv_data + m * gOv_stride);
1871               copy_stub(device_type, copy_iter_local, /*non_blocking=*/false);
1872               break;
1873             }
1874             if (grad_indices_1D_accessor[m] < input_idx) {
1875               l = m + 1;
1876             }
1877             else {
1878               r = m - 1;
1879             }
1880           }
1881         }
1882       });
1883     }
1884     else {
1885       grad_input_values = grad_values_expand;
1886     }
1887     bool grad_is_coalesced = input.is_coalesced();
1888     return at::_sparse_coo_tensor_with_dims_and_tensors(input_sparse_dim, input_dense_dim, input_sizes, input_indices.clone(at::MemoryFormat::Contiguous), grad_input_values, grad.options(), grad_is_coalesced);
1889   }
1890 }
1891 
1892 Tensor any_sparse(const Tensor& self) {
1893   TORCH_INTERNAL_ASSERT(self.is_sparse());
1894 
1895   return at::any(self._values());
1896 }
1897 
1898 Tensor bmm_sparse_cpu(const SparseTensor& self, const Tensor& mat2) {
1899   Tensor result = at::empty({}, mat2.options());
1900   return bmm_out_sparse_cpu(self, mat2, result);
1901 }
1902 
1903 // Search a sorted strided array for the rightmost instance of a value.
1904 // Array must be sorted from lowest to highest.
1905 // Returns the index of the found element.
1906 // Sets *found (returned through a pointer) to true if the search value was found, false otherwise
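// For example, over a sorted accessor holding [0, 0, 2, 2, 2, 5] (begin index
// 0, length 6), searching for 2 returns index 4 with *found == true, while
// searching for 3 sets *found to false and returns -1.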
1907 template<typename scalar_t>
1908 scalar_t binary_search_strided_rightmost(scalar_t search_val, TensorAccessor<scalar_t, 1>& sorted_arr_accessor, int64_t sorted_arr_begin_idx, int64_t length, bool* found) {
1909   if (length == 0) {
1910     *found = false;
1911     return -1;
1912   }
1913 
1914   int64_t left_ind = 0;
1915   int64_t right_ind = length - 1;
1916   // This value should be overwritten in the loop so we use
1917   // a destructive initial value to ensure disaster if that
1918   // turns out not to be the case.
1919   int64_t mid_ind = std::numeric_limits<int64_t>::max();
1920   bool done_searching = false;
1921 
1922   while (!done_searching) {
1923     mid_ind = left_ind + (right_ind - left_ind) / 2;
1924     scalar_t mid_val = sorted_arr_accessor[sorted_arr_begin_idx + mid_ind];
1925 
1926     if (mid_val > search_val) {
1927       right_ind = mid_ind-1;
1928     } else if((mid_val == search_val) && (
1929       (mid_ind == length - 1) || (sorted_arr_accessor[sorted_arr_begin_idx + mid_ind + 1] != search_val)
1930     )) {
1931       done_searching = true;
1932       *found = true;
1933     } else {
1934       left_ind = mid_ind+1;
1935     }
1936 
1937     if (left_ind > right_ind) {
1938       done_searching = true;
1939       *found = false;
1940       mid_ind = -1;
1941     }
1942   }
1943 
1944   return mid_ind;
1945 }
1946 
1947 Tensor& bmm_out_sparse_cpu(const SparseTensor& self, const Tensor& mat2, Tensor& result) {
1948   TORCH_CHECK(!mat2.is_sparse(), "bmm_sparse: Tensor 'mat2' must be dense");
1949 
1950   TORCH_CHECK(self.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self.dense_dim());
1951   TORCH_CHECK(self.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self.sparse_dim());
1952   TORCH_CHECK(mat2.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2.dim());
1953 
1954   TORCH_CHECK(self.size(0) == mat2.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match");
1955   TORCH_CHECK(self.size(2) == mat2.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match");
1956 
1957   result.resize_({self.size(0), self.size(1), mat2.size(2)});
1958 
1959   if (self._nnz() == 0) {
1960     result.zero_();
1961     return result;
1962   }
1963 
1964   // We first need to coalesce to get all of the first-dimension indices
1965   // in order, since we'll be sending each matrix into the MM operation
1966   SparseTensor self_coalesced = self.coalesce();
1967 
1968   int64_t nnz =        self_coalesced._nnz();
1969   Tensor indices = self_coalesced._indices();
1970   Tensor values =      self_coalesced._values();
1971 
1972   Tensor indices_dim0 = indices[0];
1973   auto indices_dim0_accessor = indices_dim0.accessor<int64_t, 1>();
1974   Tensor indices_dim1_dim2 = indices.slice(0, 1, 3);
1975 
1976   int64_t dim_i = self_coalesced.size(1);
1977   int64_t dim_j = self_coalesced.size(2);
1978   int64_t dim_k = mat2.size(2);
1979 
1980   Scalar beta = 0;
1981   Tensor t_dummy;
1982   Scalar alpha = 1;
1983 
1984   int64_t mat_el_begin_idx = 0;
1985 
1986   int64_t num_matrices = self_coalesced.size(0);
1987 
1988   // Iterate through each set of 2D matrices within the 3D
1989   // tensor inputs, performing a matrix multiply with each one.
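  // For example, if indices_dim0 is [1, 1, 3] after coalescing, matrix 0 is
  // simply zeroed out (cur_mat_num < start_mat_num), matrix 1 multiplies
  // nonzeros [0, 2), matrix 2 has no nonzeros and is zeroed out, and matrix 3
  // multiplies nonzeros [2, 3).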
1990   int64_t start_mat_num = indices_dim0_accessor[0];
1991   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
1992     values.scalar_type(), "bmm_sparse_dense", [&] {
1993       for (int64_t cur_mat_num = 0;
1994         (cur_mat_num < num_matrices);
1995         cur_mat_num++
1996       ) {
1997         // If there are sparse matrices at the beginning or end that
1998         // have all zero elements, we need to zero out the result matrix.
1999         if ((cur_mat_num < start_mat_num) || (mat_el_begin_idx >= nnz)) {
2000           result[cur_mat_num].zero_();
2001           continue;
2002         }
2003 
2004         // Search for the range of sparse tensor elements that
2005         // correspond to the current matrix number. We already know
2006         // where the current matrix begins, so we just need to find
2007         // the end. The search excludes everything to the left of
2008         // the starting point, for best performance
2009         bool mat_end_found;
2010         int64_t mat_el_end_idx = binary_search_strided_rightmost(
2011           cur_mat_num,
2012           indices_dim0_accessor,
2013           mat_el_begin_idx,
2014           nnz-mat_el_begin_idx,
2015           &mat_end_found
2016         ) + mat_el_begin_idx;
2017 
2018         if (mat_end_found) {
2019           mat_el_end_idx++;
2020 
2021           // Create tensors to view just the current set of matrices
2022           const Tensor dense_matrix = mat2[cur_mat_num];
2023           Tensor result_matrix = result[cur_mat_num];
2024           Tensor sparse_indices = indices_dim1_dim2.slice(1, mat_el_begin_idx, mat_el_end_idx);
2025           Tensor sparse_values = values.slice(0, mat_el_begin_idx, mat_el_end_idx);
2026           int64_t sparse_nnz = mat_el_end_idx - mat_el_begin_idx;
2027 
2028 
2029           s_addmm_out_sparse_dense_worker<scalar_t>(
2030             sparse_nnz,
2031             dim_i, dim_j, dim_k,
2032             result_matrix,
2033             beta, t_dummy, alpha,
2034             sparse_indices, sparse_values,
2035             dense_matrix
2036           );
2037           mat_el_begin_idx = mat_el_end_idx;
2038 
2039         // If no elements for this sparse matrix are found, then
2040         // it's a zero matrix and we need to zero out the result
2041         } else {
2042           result[cur_mat_num].zero_();
2043         }
2044       }
2045     }
2046   );
2047   return result;
2048 }
2049 
2050 Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
2051   TORCH_INTERNAL_ASSERT(input.is_sparse());
2052   if (!is_same_tensor(result, input)) {
2053     copy_sparse_to_sparse_(result, input);
2054   }
2055   if (!input.is_complex()) {
2056     return result;
2057   }
2058   Tensor result_values = result._values();
2059   at::conj_physical_out(result_values, input._values());
2060   return result;
2061 }
2062 
2063 } // namespace at::native
2064