xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ReduceOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/ReduceOps.h>
3 
4 #include <ATen/core/Tensor.h>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/Dispatch_v2.h>
8 #include <ATen/Parallel.h>
9 #include <ATen/WrapDimUtils.h>
10 #include <ATen/WrapDimUtilsMulti.h>
11 #include <ATen/TensorIterator.h>
12 #include <ATen/TensorOperators.h>
13 #include <ATen/NamedTensorUtils.h>
14 #include <ATen/native/ReduceOpsUtils.h>
15 #include <ATen/native/Resize.h>
16 #include <ATen/native/TensorDimApply.h>
17 #include <ATen/core/grad_mode.h>
18 #include <ATen/TensorSubclassLikeUtils.h>
19 
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #include <ATen/NativeFunctions.h>
23 #else
24 #include <ATen/ops/_cummax_helper.h>
25 #include <ATen/ops/_cummax_helper_native.h>
26 #include <ATen/ops/_cummin_helper.h>
27 #include <ATen/ops/_cummin_helper_native.h>
28 #include <ATen/ops/_is_all_true_native.h>
29 #include <ATen/ops/_is_any_true_native.h>
30 #include <ATen/ops/_logcumsumexp.h>
31 #include <ATen/ops/_logcumsumexp_native.h>
32 #include <ATen/ops/_sparse_csr_sum.h>
33 #include <ATen/ops/_sparse_sum.h>
34 #include <ATen/ops/_sparse_sum_native.h>
35 #include <ATen/ops/_to_copy.h>
36 #include <ATen/ops/add.h>
37 #include <ATen/ops/all_meta.h>
38 #include <ATen/ops/all_native.h>
39 #include <ATen/ops/amax.h>
40 #include <ATen/ops/amax_meta.h>
41 #include <ATen/ops/amax_native.h>
42 #include <ATen/ops/amin_meta.h>
43 #include <ATen/ops/amin_native.h>
44 #include <ATen/ops/aminmax_meta.h>
45 #include <ATen/ops/aminmax_native.h>
46 #include <ATen/ops/any_meta.h>
47 #include <ATen/ops/any_native.h>
48 #include <ATen/ops/argmax_meta.h>
49 #include <ATen/ops/argmax_native.h>
50 #include <ATen/ops/argmin_meta.h>
51 #include <ATen/ops/argmin_native.h>
52 #include <ATen/ops/cat.h>
53 #include <ATen/ops/complex.h>
54 #include <ATen/ops/cummax.h>
55 #include <ATen/ops/cummax_native.h>
56 #include <ATen/ops/cummaxmin_backward_native.h>
57 #include <ATen/ops/cummin.h>
58 #include <ATen/ops/cummin_native.h>
59 #include <ATen/ops/cumprod.h>
60 #include <ATen/ops/cumprod_backward_native.h>
61 #include <ATen/ops/cumprod_meta.h>
62 #include <ATen/ops/cumprod_native.h>
63 #include <ATen/ops/cumsum.h>
64 #include <ATen/ops/cumsum_meta.h>
65 #include <ATen/ops/cumsum_native.h>
66 #include <ATen/ops/diff_native.h>
67 #include <ATen/ops/dist_native.h>
68 #include <ATen/ops/empty.h>
69 #include <ATen/ops/empty_like.h>
70 #include <ATen/ops/equal_native.h>
71 #include <ATen/ops/exp.h>
72 #include <ATen/ops/gather.h>
73 #include <ATen/ops/gradient_native.h>
74 #include <ATen/ops/imag.h>
75 #include <ATen/ops/isnan_native.h>
76 #include <ATen/ops/linalg_vector_norm.h>
77 #include <ATen/ops/logcumsumexp.h>
78 #include <ATen/ops/logcumsumexp_native.h>
79 #include <ATen/ops/logical_xor.h>
80 #include <ATen/ops/logsumexp.h>
81 #include <ATen/ops/logsumexp_native.h>
82 #include <ATen/ops/mean.h>
83 #include <ATen/ops/mean_meta.h>
84 #include <ATen/ops/mean_native.h>
85 #include <ATen/ops/nanmean_native.h>
86 #include <ATen/ops/nansum.h>
87 #include <ATen/ops/nansum_native.h>
88 #include <ATen/ops/narrow.h>
89 #include <ATen/ops/native_norm.h>
90 #include <ATen/ops/ne.h>
91 #include <ATen/ops/norm.h>
92 #include <ATen/ops/norm_meta.h>
93 #include <ATen/ops/norm_native.h>
94 #include <ATen/ops/ones.h>
95 #include <ATen/ops/prod.h>
96 #include <ATen/ops/prod_meta.h>
97 #include <ATen/ops/prod_native.h>
98 #include <ATen/ops/real.h>
99 #include <ATen/ops/slice.h>
100 #include <ATen/ops/special_logsumexp_native.h>
101 #include <ATen/ops/sqrt.h>
102 #include <ATen/ops/squeeze.h>
103 #include <ATen/ops/stack.h>
104 #include <ATen/ops/std.h>
105 #include <ATen/ops/std_mean.h>
106 #include <ATen/ops/std_mean_native.h>
107 #include <ATen/ops/std_native.h>
108 #include <ATen/ops/sub.h>
109 #include <ATen/ops/sum.h>
110 #include <ATen/ops/sum_meta.h>
111 #include <ATen/ops/sum_native.h>
112 #include <ATen/ops/trace_native.h>
113 #include <ATen/ops/value_selecting_reduction_backward_native.h>
114 #include <ATen/ops/var.h>
115 #include <ATen/ops/var_mean.h>
116 #include <ATen/ops/var_mean_native.h>
117 #include <ATen/ops/var_native.h>
118 #include <ATen/ops/zeros.h>
119 #include <ATen/ops/zeros_like.h>
120 #endif
121 
122 #include <c10/util/irange.h>
123 #include <c10/util/SmallBuffer.h>
124 
125 #include <algorithm>
126 #include <cmath>
127 #include <functional>
128 #include <limits>
129 #include <numeric>
130 #include <type_traits>
131 #include <utility>
132 #include <vector>
133 
134 namespace at::meta {
135 
136 static ScalarType infer_dtype_from_optional(
137     const Tensor& self,
138     const std::optional<ScalarType>& opt_dtype,
139     const Tensor& result) {
140   // 'opt_dtype' has the priority for both cases.
141   if (result.defined()) {
142     // If 'opt_dtype' is not provided, fall back to the result type.
143     return opt_dtype.value_or(result.scalar_type());
144   } else {
145     // Last case is to get the self type.
146     // If the self type is an integer, we promote it to kLong.
147     return at::native::get_dtype_from_self(self, opt_dtype, true);
148   }
149 }
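// Illustrative sketch of the rules above (hypothetical call sites, not actual
// tests): a sum over an int32 tensor with no dtype and no out tensor resolves
// to kLong, the same call with a defined kDouble out tensor resolves to
// kDouble, and an explicit opt_dtype takes precedence in both cases.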
150 
151 static IntArrayRef optional_to_arrayref(const std::optional<int64_t>& opt) {
152   return opt.has_value() ? opt.value() : IntArrayRef{};
153 }
154 
155 static ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result) {
156   // Refer [all, any : uint8 compatibility]
157   if (result.defined()) {
158     return result.scalar_type();
159   } else {
160     return (self.scalar_type() == kByte) ? kByte : kBool;
161   }
162 }
163 
164 static void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
165   if (result.defined()) {
166     // Refer [all, any : uint8 compatibility]
167     TORCH_CHECK(
168         result.scalar_type() == ScalarType::Bool ||
169             result.scalar_type() == ScalarType::Byte,
170         name, " only supports bool tensor for result, got: ",
171         result.scalar_type());
172   }
173 }
174 
175 // Note [all, any : uint8 compatibility]:
176 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177 // For NumPy compatibility, `all` and `any` return
178 // a Tensor of dtype `bool`. However, for backwards compatibility,
179 // they return a Tensor of the same dtype `uint8` when the input is `uint8`.
180 // Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
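// As a sketch of the intended behavior described in the note above
// (illustrative, not taken from a test):
//   all() / any() on a kByte (uint8) input  -> kByte result
//   all() / any() on any other input dtype  -> kBool result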
181 static void allany_meta(
182     impl::MetaBase& meta,
183     const char* name,
184     const Tensor& self,
185     OptionalIntArrayRef dims,
186     bool keepdim) {
187   const auto& result = meta.maybe_get_output();
188   check_result_is_bytebool(name, self, result);
189   auto out_dtype = get_result_or_bytebool_dtype(self, result);
190   resize_reduction(meta, self, dims, keepdim, out_dtype, /*allow_empty_dims=*/true);
191 }
192 
193 TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
194   allany_meta(*this, "all", self, dim, keepdim);
195 }
196 
197 TORCH_META_FUNC2(all, dims)(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
198   allany_meta(*this, "all", self, dim, keepdim);
199 }
200 
201 TORCH_META_FUNC(all)(const Tensor& self) {
202   allany_meta(*this, "all", self, {}, false);
203 }
204 
205 TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
206   allany_meta(*this, "any", self, dim, keepdim);
207 }
208 
209 TORCH_META_FUNC2(any, dims)(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
210   allany_meta(*this, "any", self, dim, keepdim);
211 }
212 
213 TORCH_META_FUNC(any)(const Tensor& self) {
214   allany_meta(*this, "any", self, {}, false);
215 }
216 
217 static void check_argmax_argmin(
218     const char* name,
219     const Tensor& self,
220     const std::optional<int64_t>& dim) {
221   if (dim.has_value()) {
222     auto dim_ = maybe_wrap_dim(dim.value(), self.dim());
223     native::zero_numel_check_dims(self, dim_, name);
224   } else {
225     TORCH_CHECK_INDEX(
226         self.numel() != 0,
227         name, ": Expected reduction dim to be specified for input.numel() == 0.");
228   }
229 }
230 
231 TORCH_META_FUNC(argmax)
232 (const Tensor& self, std::optional<int64_t> dim, bool keepdim) {
233   check_argmax_argmin("argmax()", self, dim);
234   resize_reduction(*this, self, optional_to_arrayref(dim), keepdim, kLong);
235 }
236 
237 TORCH_META_FUNC(argmin)
238 (const Tensor& self, std::optional<int64_t> dim, bool keepdim) {
239   check_argmax_argmin("argmin()", self, dim);
240   resize_reduction(*this, self, optional_to_arrayref(dim), keepdim, kLong);
241 }
242 
243 static void meta_func_cum_ops(
244     impl::MetaBase& meta,
245     const char* name,
246     const Tensor& self,
247     int64_t dim,
248     std::optional<ScalarType> dtype) {
249   // Checking whether 'dim' is valid.
250   maybe_wrap_dim(dim, self.dim());
251 
252   const auto& result = meta.maybe_get_output();
253   ScalarType out_dtype{};
254 
255   if (result.defined()) {
256     out_dtype = dtype.value_or(result.scalar_type());
257   } else {
258     auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
259     out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
260   }
261 
262   meta.set_output_raw_strided(0, self.sizes(), {}, self.options().dtype(out_dtype));
263   namedinference::propagate_names(result, self);
264 }
265 
266 TORCH_META_FUNC(cumsum)
267 (const Tensor& self, int64_t dim, std::optional<ScalarType> dtype) {
268   meta_func_cum_ops(*this, "cumsum", self, dim, dtype);
269 }
270 
271 TORCH_META_FUNC(cumprod)
272 (const Tensor& self, int64_t dim, std::optional<ScalarType> dtype) {
273   meta_func_cum_ops(*this, "cumprod", self, dim, dtype);
274 }
275 
276 TORCH_META_FUNC2(sum, dim_IntList)
277 (const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
278   auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
279   resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
280 }
281 
282 TORCH_META_FUNC2(prod, dim_int)
283 (const Tensor& self,
284  int64_t dim,
285  bool keepdim,
286  std::optional<ScalarType> dtype) {
287   auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output());
288   resize_reduction(*this, self, dim, keepdim, out_dtype);
289 }
290 
291 TORCH_META_FUNC2(mean, dim)
292 (const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
293   auto in_dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
294 
295   if (!at::isFloatingType(in_dtype) && !at::isComplexType(in_dtype)) {
296     std::string what = "Input";
297     std::string dtype = toString(self.scalar_type());
298 
299     if (opt_dtype.has_value()) {
300       what = "Optional";
301       dtype = toString(opt_dtype.value());
302     }
303 
304     TORCH_CHECK(
305         false,
306         "mean(): could not infer output dtype. ",
307         what, " dtype must be either a floating point or complex dtype. ",
308         "Got: ", dtype);
309   }
310 
311   auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
312   resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
313 }
314 
315 static ScalarType get_result_or_self_value_dtype(
316     const Tensor& self,
317     const Tensor& result,
318     const std::optional<ScalarType>& dtype) {
319   if (result.defined()) {
320     return result.scalar_type();
321   } else {
322     return dtype.value_or(toRealValueType(self.scalar_type()));
323   }
324 }
325 
326 TORCH_META_FUNC2(norm, ScalarOpt_dim)
327 (const Tensor& self, const OptionalScalarRef p, IntArrayRef dim, bool keepdim) {
328   TORCH_CHECK(
329       at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
330       "norm(): input dtype should be either floating point or complex. "
331       "Got ", self.scalar_type(), " instead.");
332 
333   auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), std::nullopt);
334   resize_reduction(*this, self, dim, keepdim, out_dtype);
335 }
336 
337 TORCH_META_FUNC2(norm, ScalarOpt_dim_dtype)
338 (const Tensor& self,
339  const OptionalScalarRef p,
340  IntArrayRef dim,
341  bool keepdim,
342  ScalarType dtype) {
343   TORCH_CHECK(
344       at::isFloatingType(dtype) || at::isComplexType(dtype),
345       "norm(): the desired output dtype should be either floating point or complex. "
346       "Got ", dtype, " instead.");
347 
348   auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), dtype);
349   resize_reduction(*this, self, dim, keepdim, out_dtype);
350 }
351 
352 TORCH_META_FUNC(aminmax)
353 (const Tensor& self, std::optional<int64_t> dim_opt, bool keepdim) {
354   DimVector shape;
355   if (dim_opt.has_value()) {
356     auto dim = maybe_wrap_dim(dim_opt.value(), self.ndimension());
357     native::zero_numel_check_dims(self, dim, "aminmax");
358     shape = get_reduction_shape(self, dim, keepdim);
359   } else {
360     TORCH_CHECK(
361         self.numel() > 0,
362         "aminmax(): cannot compute aminmax over an empty dimension as the "
363         "operation has no identity.");
364     if (keepdim) {
365       shape = DimVector(self.ndimension(), 1);
366     }
367   }
368   const auto options = self.options();
369   this->set_output_raw_strided(0, shape, {}, options);
370   this->set_output_raw_strided(1, shape, {}, options);
371 }
372 
373 TORCH_META_FUNC(amax)
374 (const Tensor& self, IntArrayRef dim, bool keepdim) {
375   auto maybe_result = maybe_get_output();
376   if (maybe_result.defined()) {
377     TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
378             self.scalar_type(), " for input's dtype and ",  maybe_result.scalar_type(), " for out's dtype.");
379   }
380   if (self.numel() == 0) {
381     at::native::zero_numel_check_dims(self, dim, "amax()");
382   }
383   const ScalarType& out_dtype = maybe_result.defined() ? maybe_result.scalar_type() : self.scalar_type();
384   resize_reduction(*this, self, dim, keepdim, out_dtype);
385 }
386 
387 TORCH_META_FUNC(amin)
388 (const Tensor& self, IntArrayRef dim, bool keepdim) {
389   auto maybe_result = maybe_get_output();
390   if (maybe_result.defined()) {
391     TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
392                 self.scalar_type(), " for input's dtype and ",  maybe_result.scalar_type(), " for out's dtype.");
393   }
394   if (self.numel() == 0) {
395     at::native::zero_numel_check_dims(self, dim, "amin()");
396   }
397   const ScalarType& out_dtype = maybe_result.defined() ? maybe_result.scalar_type() : self.scalar_type();
398   resize_reduction(*this, self, dim, keepdim, out_dtype);
399 }
400 
401 } // namespace at::meta
402 
403 namespace at::native {
404 
405 DEFINE_DISPATCH(aminmax_stub);
406 DEFINE_DISPATCH(aminmax_allreduce_stub);
407 
408 TORCH_IMPL_FUNC(aminmax_out)
409 (const Tensor& self,
410  std::optional<int64_t> dim_opt,
411  bool keepdim,
412  const Tensor& min,
413  const Tensor& max) {
414   auto mutable_min = const_cast<Tensor&>(min);
415   auto mutable_max = const_cast<Tensor&>(max);
416   if (dim_opt.has_value()) {
417     aminmax_stub(
418         self.device().type(),
419         self,
420         maybe_wrap_dim(dim_opt.value(), self.ndimension()),
421         keepdim,
422         mutable_min,
423         mutable_max);
424   } else {
425     aminmax_allreduce_stub(self.device().type(), self.contiguous(), mutable_min, mutable_max);
426   }
427 }
428 
429 DEFINE_DISPATCH(sum_stub);
430 DEFINE_DISPATCH(nansum_stub);
431 DEFINE_DISPATCH(std_var_stub);
432 DEFINE_DISPATCH(prod_stub);
433 DEFINE_DISPATCH(norm_stub);
434 DEFINE_DISPATCH(mean_stub);
435 DEFINE_DISPATCH(and_stub);
436 DEFINE_DISPATCH(or_stub);
437 DEFINE_DISPATCH(min_values_stub);
438 DEFINE_DISPATCH(max_values_stub);
439 DEFINE_DISPATCH(argmax_stub);
440 DEFINE_DISPATCH(argmin_stub);
441 DEFINE_DISPATCH(cumsum_stub);
442 DEFINE_DISPATCH(cumprod_stub);
443 DEFINE_DISPATCH(logcumsumexp_stub);
444 
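// What logcumsumexp computes, written out for a 1-D input (a definitional
// sketch, not an implementation detail of the stub below):
//   out[i] = log(sum_{j <= i} exp(self[j]))
// i.e. a numerically stable running log-sum-exp along `dim`.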
445 Tensor _logcumsumexp_cpu(const Tensor& self, int64_t dim) {
446   Tensor result = at::empty_like(self, MemoryFormat::Contiguous);
447   return _logcumsumexp_out_cpu(self, dim, result);
448 }
449 
450 Tensor& _logcumsumexp_out_cpu(const Tensor& self, int64_t dim, Tensor& result) {
451   logcumsumexp_stub(self.device().type(), result, self, dim);
452   return result;
453 }
454 
455 Tensor logcumsumexp(const Tensor& self, int64_t dim) {
456   auto result = [&]() {
457     NoNamesGuard guard;
458     return at::_logcumsumexp(self, dim);
459   }();
460   namedinference::propagate_names(result, self);
461   return result;
462 }
463 
464 Tensor& logcumsumexp_out(const Tensor& self, int64_t dim, Tensor& result) {
465   check_scalar_type_device_layout_equal(result, self);
466   {
467     NoNamesGuard guard;
468     at::_logcumsumexp_out(result, self.toType(result.scalar_type()), dim);
469   }
470   namedinference::propagate_names(result, self);
471   return result;
472 }
473 
474 template <class Stub>
475 void impl_func_cum_ops(
476     const Tensor& self,
477     int64_t dim,
478     const Tensor& result,
479     Stub& stub) {
480   NoNamesGuard guard;
481   if (self.dim() == 0) {
482     result.fill_(self);
483   } else if (self.numel() == 0) {
484     result.zero_();
485   } else {
486     dim = maybe_wrap_dim(dim, self.dim());
487     stub(self.device().type(), result, self.to(result.scalar_type()), dim);
488   }
489 }
490 
491 TORCH_IMPL_FUNC(cumsum_out)
492 (const Tensor& self,
493  int64_t dim,
494  std::optional<ScalarType> dtype,
495  const Tensor& result) {
496   impl_func_cum_ops(self, dim, result, cumsum_stub);
497 }
498 
499 TORCH_IMPL_FUNC(cumprod_out)
500 (const Tensor& self,
501  int64_t dim,
502  std::optional<ScalarType> dtype,
503  const Tensor& result) {
504   impl_func_cum_ops(self, dim, result, cumprod_stub);
505 }
506 
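// reversed_cumsum computes suffix sums along `dim`: for a 1-D tensor,
// out[k] = sum_{j >= k} w[j] (flip, cumsum, flip back).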
507 static Tensor reversed_cumsum(const Tensor& w, int64_t dim) {
508   return w.flip(dim).cumsum(dim).flip(dim);
509 }
510 
511 Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, const Tensor& output) {
512   /*
513     We show here how to derive an O(n) gradient formula for
514     arbitrary inputs. It follows via a basic application of the
515     chain rule together with a number of observations for different
516     cases. We assume that x is an n-dimensional vector and y = cumprod(x).
517     In the actual implementation we will need to play a bit with masks
518     to be able to implement the formulas deduced here for tensors.
519 
520     We will first deduce the formula for the case when
521     x[i] != 0 for 1 <= i <= n.
522 
523     For F : R^n -> R the cost function (we will look at the complex case later),
524     we have
525 
526     dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k)   (1)
527 
528     The term dF / dy_j is just grad_output[j] (assuming again
529     everything is one-dimensional).
530 
531     The term (dy_j / dx_k) is easily seen to be
532 
533     if j >= k
534       dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i
535     else:
536       dy_j / dx_k = 0
537 
538     Note that the indicator (j>=k) can be taken out
539     by replacing the sum in (1) with a sum from
540     k <= j <= n.
541 
542     Thus,
543     dF / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k)
544 
545     with
546     dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i     (2)
547 
548     Note that this last term is just the cumulative product
549     with k omitted. Thus, if x_k (the input) is nonzero, we can
550     just express this as
551 
552     dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k
553                 = y_j / x_k
554 
555     So therefore,
556 
557     dF / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k
558 
559     This formula only makes sense when input[i] != 0 for every i.
560 
561     Assume now that there exists at least a zero in the input.
562     Denote by z1 the first element 1 <= z1 <= n with input[z1] = 0
563     and z2 the second element z1 < z2 <= n with input[z2] = 0,
564     (or z2 = n if there is just one zero in input)
565 
566     We have three cases.
567 
568     k > z1:
569     Looking at (2), we see that dy_j / dx_k = 0, for j >= k, as these terms
570     all include a x_{z1} which is zero. As such, dF / dx_k = 0 in this case
571 
572     k < z1:
573     Reasoning as in the previous case, we see that for these elements we have that
574 
575     dF / dx_k = sum_{k <= j < z1} grad_output[j] * (dy_j / dx_k)
576 
577     as the terms of the sum for j in z1 <= j <= n are all zero
578 
579     k = z1:
580     Similar to the case k < z1, we have that
581 
582     dF / dx_z1 = sum_{z1 <= j < z2} grad_output[j] * (dy_j / dx_z1)
583 
584     This case has a subtlety though. To compute (dy_j / dx_z1), we cannot use the formula
585 
586     dy_j / dx_z1 = y_j / x_z1
587 
588     as y_j = x_z1 = 0 for j >= z1. Instead, computing dy_j / dx_z1 directly from (2)
589     and summing over z1 <= j < z2 gives
590 
591     dF / dx_z1 = prod(x[:z1]) * (grad_output[z1] + sum(grad_output[z1+1:z2] * cumprod(x[z1+1:z2])))
592 
593     When the inputs are complex, this map is holomorphic. As such, its backward
594     is just the conjugate of the usual backward, which simplifies to conjugating
595     the input. We may also reuse the output since, the map being holomorphic,
596     cumprod(input.conj()) = cumprod(input).conj()
597   */
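  /*
    A quick numeric sanity check of the nonzero-input formula above
    (illustrative only): for x = [2, 3, 4] we have y = cumprod(x) = [2, 6, 24],
    and for k = 2

    dF / dx_2 = grad_output[2] * y_2 / x_2 + grad_output[3] * y_3 / x_2
              = 2 * grad_output[2] + 8 * grad_output[3]

    which agrees with summing dy_j / dx_2 = prod_{1 <= i <= j, i != 2} x_i
    directly: dy_2 / dx_2 = x_1 = 2 and dy_3 / dx_2 = x_1 * x_3 = 8.
  */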
598 
599   if (input.sym_numel() <= 1) {
600     return grad;
601   }
602   dim = at::maybe_wrap_dim(dim, input.dim());
603   const int64_t dim_size = input.sym_sizes()[dim].guard_int(__FILE__, __LINE__);
604   if (dim_size == 1) {
605     return grad;
606   }
607 
608   // To enable complex support.
609   // From this line on `input_conj` and `output_conj`
610   // are interchangeable with `input` and `output`.
611   auto input_conj = input.conj();
612   auto output_conj = output.conj();
613 
614   // For Composite Compliance, we always choose the slower but composite compliant path.
615   bool are_inputs_tensors_subclass = areAnyTensorSubclassLike({input, grad, output});
616 
617   const auto w = output_conj * grad;
618   const auto is_zero = input == 0;
619   if (!are_inputs_tensors_subclass) {
620     if (is_zero.any().item<uint8_t>() == 0) {
621       return reversed_cumsum(w, dim).div(input_conj);
622     }
623   }
624 
625   // If we are not computing a second order gradient, we can use an
626   // O(n) implementation. The derivative of this implementation is _not_
627   // the second derivative of cumprod. As such, we fall back to a less efficient
628   // O(n^2) implementation when at::GradMode::is_enabled().
629   if (!at::GradMode::is_enabled() && !are_inputs_tensors_subclass) {
630     // n.b. This could probably be implemented much faster with a kernel
631 
632     // From here on we need to use some mask gymnastics to
633     // account for the tensorial dimensions
634     // We do a cumsum of the zeros along the dimension.
635     // For a vector is_zero = [False, True, False, True, False]
636     // we would have cumsum = [0, 1, 1, 2, 2]
637     // As such we have (in python code for simplicity)
638     // The mask for the range [0, z1):
639     // cumsum == 0
640     // The indices of the first zero z1 and zeros when
641     // there is no first zero:
642     // indices = (cumsum == 1).max(dim, keepdim=True).indices
643     // The mask for the first zero:
644     // zeros_like(indices).scatter_(dim, indices, 1.) & cumsum == 1
645     // Note that the logical_and with cumsum == 1 accounts
646     // for the case when there is no first zero
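    // Continuing the example above for is_zero = [False, True, False, True, False]
    // (illustrative only):
    //   cumsum == 0      -> [True,  False, False, False, False]   (k < z1)
    //   cumsum == 1      -> [False, True,  True,  False, False]   (z1 <= k < z2)
    //   first-zero mask  -> [False, True,  False, False, False]   (k == z1)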
647     Tensor grad_input = at::zeros_symint(input.sym_sizes(), grad.options());
648     const auto cumsum = is_zero.cumsum(dim);
649 
650     // case k < z1
651     // select everything before the first zero [0, z1)
652     auto mask = cumsum == 0;
653     // equiv to grad_input[mask] = deriv[grad]
654     grad_input.masked_scatter_(mask,
655         reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input_conj).masked_select(mask));
656     // select everything from the first zero to the second zero [z1, z2)
657     mask = cumsum == 1;
658 
659     // case k = z1
660     // We start by selecting the first zero [z1]
661     // We locate the indices of the first zero using the max function
662     // We then go from the indices to a mask using scatter_
663     // When there is no zero in the slice, max will return the index 0.
664     // To account for this, we need to do an intersection with mask,
665     // which is true in the range [z1, z2)
666     const auto first_zero_index = std::get<1>(mask.max(dim, /*keepdim*/ true));
667     const auto first_zero_mask = at::zeros_like(mask)
668                                   .scatter_(dim, first_zero_index, /*src*/ 1)
669                                   .logical_and_(mask);
670 
671     // select everything between the first zero and the second zero (z1, z2)
672     mask &= ~first_zero_mask;
673     // here we compute
674     // dF / dx_z1 = (grad[z1] + sum(grad[z1+1:z2] * cumprod(input[z1+1:z2]))) * output[z1-1]
675     // relu_() necessary as gather does not support negative indices
676     // finally, we do grad_input[z1] = dy_j / dx_z1
677     grad_input.masked_scatter_(first_zero_mask,
678                                input_conj.masked_fill(~mask, 1.).cumprod(dim)
679                                     .mul_(grad.masked_fill(cumsum != 1, 0.))
680                                     .sum(dim, /*keepdim*/true)
681                                     .mul_(at::gather(output_conj, dim, (first_zero_index - 1).relu_())
682                                           .masked_fill_(first_zero_index == 0, 1.))
683                                     .masked_select(first_zero_mask));
684     return grad_input;
685   } else { // GradMode::is_enabled() or tensor subclass inputs
686     /*
687     If the input is nonzero, we need to calculate the dy_j / dx_k
688     by using the formula (2), called in the code omitted_products.
689 
690     The way the code calculates it is simply by noting that
691 
692     prod_{1 <= i <= j, i != k} x_i
693         = (prod_{1 <= i <= k} x_i) * (prod_{k + 1 <= i <= j} x_i)
694 
695     the first term is calculated as prods_until_k, which, since it
696     doesn't depend on j, is easy to vectorize.
697 
698     The second term (indexed by j) is the cumulative product of
699     x_{k+1}, x_{k+2}, ..., x_n, and it's named in the code
700     prods_from_k_plus_1, and it's calculated as a cumprod.
701 
702     In order to vectorize this properly, we need to add to
703     omitted_products the dimensions where k > j, and therefore
704     dy_j / dx_k = 0, which is done right after the assert.
705     */
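    /*
    Illustrative example (not part of the original derivation): for a length-3
    input x = [x_1, x_2, x_3], the loop below (with 0-based k) produces

      k = 0: omitted_products = [1, x_2, x_2 * x_3]   (dy_1/dx_1, dy_2/dx_1, dy_3/dx_1)
      k = 1: omitted_products = [x_1, x_1 * x_3]      (dy_2/dx_2, dy_3/dx_2)
      k = 2: omitted_products = [x_1 * x_2]           (dy_3/dx_3)

    and grad_slice at step k is sum_{j >= k} grad[j] * (dy_j / dx_k).
    */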
706 
707     Tensor grad_input;
708     // For Composite Compliance, we will use
709     // at::stack on the grad slices, hence the vector.
710     std::vector<Tensor> grad_inputs;
711     if (are_inputs_tensors_subclass) {
712       grad_inputs.reserve(dim_size);
713     } else {
714       grad_input = at::zeros(input.sizes(), grad.options());
715     }
716     auto ones_size = input.sym_sizes().vec();
717     ones_size[dim] = 1;
718     const Tensor ones = at::ones({1}, grad.options()).expand_symint(ones_size);
719     Tensor prods_from_k_plus_1;
720     Tensor omitted_products;
721     for (const auto k : c10::irange(dim_size)) {
722       if (k == 0) {
723         prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k + 1), dim);
724         omitted_products = at::cat({ones, std::move(prods_from_k_plus_1)}, dim);
725       } else if (k == dim_size - 1) {
726         const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
727         omitted_products = prods_until_k;
728       } else {
729         const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
730         prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k+1), dim);
731         omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
732         omitted_products = at::cat({prods_until_k, omitted_products}, dim);
733       }
734 
735       // At this point omitted_products is the same size
736       // as input, except on the dimension dim where it's
737       // dim_size - k
738       TORCH_CHECK(omitted_products.sym_size(dim) == dim_size - k);
739 
740       auto grad_slice = at::sum(grad.slice(dim, k) * omitted_products, dim);
741       if (are_inputs_tensors_subclass) {
742         grad_inputs.push_back(grad_slice);
743       } else {
744         grad_input.select(dim, k).copy_(grad_slice);
745       }
746     }
747 
748     return are_inputs_tensors_subclass ? at::stack(grad_inputs, dim) : std::move(grad_input);
749   }
750 }
751 
752 // Implement std::is_nan<IntegralType> for MSVC.
753 namespace {
754 #ifdef _MSC_VER
755 template<typename T>
756 inline typename std::enable_if<std::is_integral<T>::value, bool>::type isnan_(T x) {
757   return false;
758 }
759 template<typename T>
760 inline typename std::enable_if<!std::is_integral<T>::value, bool>::type isnan_(T x) {
761   return std::isnan(x);
762 }
763 #else
764 template<typename T>
765 inline bool isnan_(T x) {
766   return std::isnan(x);
767 }
768 #endif
769 }
770 
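// NaN handling in the running max/min below: once a NaN is encountered it
// becomes (and stays) the running value, so e.g. cummax on [1., nan, 2.]
// yields values [1., nan, nan] (illustrative example).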
771 template<typename T1, typename T2, typename Operation>
772 void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data,
773           int self_dim_size, int self_stride, int values_stride, int indices_stride) {
774       Operation op;
775       T1 out = c10::load(self_data);
776       int idx = 0;
777       for (const auto i : c10::irange(self_dim_size)) {
778         T1 curr_elem = c10::load(&self_data[i*self_stride]);
779         if(isnan_(curr_elem) || (!isnan_(out) && op(curr_elem, out))) {
780             out = curr_elem;
781             idx = i;
782         }
783         values_data[i*values_stride] = out;
784         indices_data[i*indices_stride] = idx;
785       }
786 }
787 
788 void cummax_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
789   AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
790     self.scalar_type(), "cummax_cpu",
791     [&] {
792       at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::greater_equal<scalar_t>>);
793     });
794 }
795 
796 std::tuple<Tensor&, Tensor&> cummax_out(const Tensor& self, int64_t dim, Tensor& values, Tensor& indices) {
797   check_scalar_type_device_layout_equal(values, self);
798   check_scalar_type_device_layout_equal(indices, at::empty({0}, self.options().dtype(at::kLong)));
799   {
800     NoNamesGuard guard;
801     at::native::resize_output(values, self.sizes());
802     at::native::resize_output(indices, self.sizes());
803     if(self.dim() == 0) {
804       values.fill_(self);
805       indices.fill_(0);
806     } else if(self.numel() != 0) {
807       dim = maybe_wrap_dim(dim, self.dim());
808       at::_cummax_helper(self, values, indices, dim);
809     }
810   }
811   namedinference::propagate_names(values, self);
812   namedinference::propagate_names(indices, self);
813   return std::forward_as_tuple(values, indices);
814 }
815 
816 std::tuple<Tensor, Tensor> cummax(const Tensor& self, int64_t dim) {
817   auto values = at::empty(self.sizes(), self.options());
818   auto indices = at::empty(self.sizes(), self.options().dtype(at::kLong));
819   at::cummax_out(values, indices, self, dim);
820   return std::make_tuple(values, indices);
821 }
822 
823 void cummin_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
824   AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
825     self.scalar_type(), "cummin_cpu",
826     [&] {
827       at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::less_equal<scalar_t>>);
828     });
829 }
830 
831 std::tuple<Tensor&, Tensor&> cummin_out(const Tensor& self, int64_t dim, Tensor& values, Tensor& indices) {
832   check_scalar_type_device_layout_equal(values, self);
833   check_scalar_type_device_layout_equal(indices, at::empty({0}, self.options().dtype(at::kLong)));
834   {
835     NoNamesGuard guard;
836     at::native::resize_output(values, self.sizes());
837     at::native::resize_output(indices, self.sizes());
838     if(self.dim() == 0) {
839       values.fill_(self);
840       indices.fill_(0);
841     } else if(self.numel() != 0) {
842       dim = maybe_wrap_dim(dim, self.dim());
843       at::_cummin_helper(self, values, indices, dim);
844     }
845   }
846   namedinference::propagate_names(values, self);
847   namedinference::propagate_names(indices, self);
848   return std::forward_as_tuple(values, indices);
849 }
850 
851 std::tuple<Tensor, Tensor> cummin(const Tensor& self, int64_t dim) {
852   auto values = at::empty(self.sizes(), self.options());
853   auto indices = at::empty(self.sizes(), self.options().dtype(at::kLong));
854   at::cummin_out(values, indices, self, dim);
855   return std::make_tuple(values, indices);
856 }
857 
858 Tensor cummaxmin_backward(const Tensor& grad, const Tensor& input, const Tensor& indices, int64_t dim) {
859   if (input.sym_numel() == 0) {
860     return input;
861   }
862   auto result = at::zeros_symint(input.sym_sizes(), input.options());
863 
864   // for composite compliance, use out-of-place variant of
865   // `scatter_add` if `indices` or `grad` is a Tensor Subclass.
866   if (areAnyTensorSubclassLike({indices, grad})) {
867     return result.scatter_add(dim, indices, grad);
868   }
869   return result.scatter_add_(dim, indices, grad);
870 }
871 
872 static Tensor prepend_append_on_dim(const Tensor& self, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append, int64_t dim) {
873   // Helper for diff that handles prepending and appending when at least one is present
874   TORCH_INTERNAL_ASSERT(prepend.has_value() || append.has_value(), "either prepend or append must have a value");
875   if (!prepend.has_value() && append.has_value()) {
876     return at::cat({self, append.value()}, dim);
877   } else if (prepend.has_value() && !append.has_value()) {
878     return at::cat({prepend.value(), self}, dim);
879   } else {
880     return at::cat({prepend.value(), self, append.value()}, dim);
881   }
882 }
883 
884 static inline void diff_check_compatible_shape(const Tensor& self, const std::optional<Tensor>& other, int64_t dim) {
885   // Helper for diff that checks whether the shape of the tensor to prepend or append
886   // is compatible with that of input
887   if (other.has_value()) {
888     int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim(), false);
889 
890     TORCH_CHECK(
891         other.value().dim() == self.dim(),
892         "diff expects prepend or append to be the same dimension as input");
893 
894     for (const auto i : c10::irange(other.value().dim())) {
895       if (i == wrapped_dim) {
896         continue;
897       }
898       TORCH_SYM_CHECK(
899           other.value().sym_size(i).sym_eq(self.sym_size(i)),
900           "diff expects the shape of tensor to prepend or append to match that of"
901           " input except along the differencing dimension;"
902           " input.size(", i, ") = ", self.sym_size(i), ", but got"
903           " tensor.size(", i, ") = ", other.value().sym_size(i));
904     }
905   }
906 }
907 
908 static inline void diff_check(const Tensor& self, int64_t n, int64_t dim, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append) {
909   // Helper for diff that checks whether its parameters are valid
910   TORCH_CHECK(
911       self.dim() >= 1,
912       "diff expects input to be at least one-dimensional");
913 
914   TORCH_CHECK(
915       n >= 0,
916       "order must be non-negative but got ", n);
917 
918   diff_check_compatible_shape(self, prepend, dim);
919   diff_check_compatible_shape(self, append, dim);
920 }
921 
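// Illustrative example (not from the tests): for a 1-D tensor t = [1, 3, 6, 10],
//   diff(t, /*n=*/1) -> [2, 3, 4]
//   diff(t, /*n=*/2) -> [1, 1]
// Each pass of the loop below subtracts the "left" slice from the "right" slice
// along `dim` (or logical-xors them for bool tensors), shrinking the dim by one.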
922 static inline Tensor diff_helper(const Tensor& self, int64_t n, int64_t dim) {
923   if (n == 0) {
924     auto result = at::zeros_like(self);
925     result.copy_(self);
926     return result;
927   }
928 
929   auto out_len = self.sym_size(dim) - 1;
930   auto result = self;
931   bool is_kBool = (self.dtype() == at::kBool);
932   n = n > self.sym_size(dim) ? self.sym_size(dim).guard_int(__FILE__, __LINE__) : n;
933 
934   for (C10_UNUSED const auto i : c10::irange(n)) {
935     if (is_kBool) {
936       result = at::logical_xor(
937         at::narrow_symint(result, dim, 1, out_len),
938         at::narrow_symint(result, dim, 0, out_len)
939       );
940     } else {
941       result = at::narrow_symint(result, dim, 1, out_len) - at::narrow_symint(result, dim, 0, out_len);
942     }
943     out_len = out_len - 1;
944   }
945 
946   return result;
947 }
948 
949 Tensor diff(const Tensor& self, int64_t n, int64_t dim, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append) {
950   diff_check(self, n, dim, prepend, append);
951   if ((!prepend.has_value() && !append.has_value()) || n == 0) {
952     return diff_helper(self, n, dim);
953   } else {
954     auto a = prepend_append_on_dim(self, prepend, append, dim);
955     return diff_helper(a, n, dim);
956   }
957 }
958 
959 static inline Tensor& diff_out_helper(const Tensor& self, int64_t n, int64_t dim, Tensor& result) {
960   if (n == 0) {
961     if (resize_output_check_symint(result, self.sym_sizes())) {
962       result.resize__symint(self.sym_sizes());
963     }
964     check_scalar_type_device_layout_equal(result, self);
965     return result.copy_(self);
966   }
967 
968   n = n > self.sym_size(dim) ? self.sym_size(dim).guard_int(__FILE__, __LINE__) : n;
969   const auto out_len = self.sym_size(dim) - n;
970   auto prev_result = self;
971 
972   if (n > 1) {
973     prev_result = diff_helper(self, n - 1, dim);
974   }
975 
976   if (self.dtype() == at::kBool) {
977     at::logical_xor_out(
978       result,
979       at::narrow_symint(prev_result, dim, 1, out_len),
980       at::narrow_symint(prev_result, dim, 0, out_len)
981     );
982   } else {
983     at::sub_out(
984       result,
985       at::narrow_symint(prev_result, dim, 1, out_len),
986       at::narrow_symint(prev_result, dim, 0, out_len)
987     );
988   }
989 
990   return result;
991 }
992 
993 Tensor& diff_out(const Tensor& self, int64_t n, int64_t dim, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append, Tensor& result) {
994   diff_check(self, n, dim, prepend, append);
995   if ((!prepend.has_value() && !append.has_value()) || n == 0) {
996     return diff_out_helper(self, n, dim, result);
997   } else {
998     auto a = prepend_append_on_dim(self, prepend, append, dim);
999     return diff_out_helper(a, n, dim, result);
1000   }
1001 }
1002 
1003 static void pre_check_gradient(const Tensor& self, std::optional<int64_t> spacing_size, at::OptionalIntArrayRef dim, int64_t edge_order) {
1004   // Helper for gradient function to make sure input data satisfies prerequisites
1005   TORCH_CHECK(self.scalar_type() != ScalarType::Byte, "torch.gradient does not support uint8 input.");
1006   if (spacing_size.has_value() && !dim.has_value()) {
1007     // NOTE: If spacing was given as a scalar, the callers of this function
1008     // create a spacing vector of the expected size, and this check passes
1009     TORCH_CHECK(spacing_size.value() == self.dim(),
1010       "torch.gradient expected spacing to be unspecified, a scalar, or a list ",
1011       "of length equal to 'self.dim() = ", self.dim(), "', since dim argument ",
1012       "was not given, but got a list of length ", spacing_size.value());
1013   }
1014   if (spacing_size.has_value() && dim.has_value()) {
1015     TORCH_CHECK(spacing_size.value() == static_cast<int64_t>(dim.value().size()),
1016     "torch.gradient expected spacing to be unspecified, a scalar or it's spacing and dim arguments to have the same length, but got a spacing argument of length ", spacing_size.value(), " and a dim argument of length ", dim.value().size(), "." );
1017   }
1018   TORCH_CHECK(edge_order == 1 || edge_order == 2, "torch.gradient only supports edge_order=1 and edge_order=2.");
1019   if (dim.has_value()) {
1020     // The following function is called to check whether the dim argument satisfies prerequisites.
1021     // The output of the function is not used for the computation of gradient.
1022     dim_list_to_bitset(dim.value(), self.dim());
1023     for (const auto i : c10::irange(dim.value().size())) {
1024       TORCH_CHECK(self.size(dim.value()[i]) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
1025     }
1026   } else {
1027     for (const auto i : c10::irange(self.dim())) {
1028       TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
1029     }
1030   }
1031 }
1032 
1033 static std::vector<Tensor> gradient_helper(const Tensor& self, TensorList coordinates, IntArrayRef dim, int64_t edge_order) {
1034   for (const auto i : c10::irange(coordinates.size())) {
1035     TORCH_CHECK(self.device() == coordinates[i].device(), "torch.gradient expected each tensor to be on the same device, but got devices ", self.device(), " and ", coordinates[i].device(), "!");
1036   }
1037 
1038   std::vector<Tensor> result;
1039   for (const auto i : c10::irange(dim.size())) {
1040     TORCH_CHECK( coordinates[i].dim() == 1, "torch.gradient expected each element of spacing to have one dimension, but got an element with ", coordinates[i].dim(), " dimensions!");
1041     int64_t direction = maybe_wrap_dim(dim[i], self.dim());
1042     Tensor prepend, append;
1043     std::vector<int64_t> shape(self.dim(),1);
1044     shape[ direction ] = -1;
1045 
1046     auto ax_dx = coordinates[i].diff(1,0);
1047     auto dx1 = at::slice(ax_dx, 0, 0, -1);
1048     auto dx2 = at::slice(ax_dx, 0, 1);
1049     auto a = (   -dx2    / (dx1*(dx1+dx2)) ).reshape(shape);
1050     auto b = ( (dx2-dx1) / (dx1*dx2)       ).reshape(shape);
1051     auto c = (    dx1    / (dx2*(dx1+dx2)) ).reshape(shape);
1052 
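    // a, b and c are the second-order finite-difference weights for a
    // non-uniform grid (spacing dx1 to the left, dx2 to the right), so the
    // `center` term below evaluates, at the interior points,
    //   f'(x_i) ~ a * f(x_{i-1}) + b * f(x_i) + c * f(x_{i+1})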
1053     auto center = a * at::slice(self, direction, 0, -2) + b * at::slice(self, direction , 1, -1) + c * at::slice(self, direction, 2);
1054     if (edge_order == 1) {
1055      prepend = (at::slice(self, direction, 1, 2  ) - at::slice(self, direction, 0, 1   )) / ax_dx[0]  ;
1056      append  = (at::slice(self, direction, -1    ) - at::slice(self, direction, -2, -1 )) / ax_dx[-1] ;
1057     } else if (edge_order == 2) {
1058      a =-(2.0 * ax_dx[0] + ax_dx[1]) / (ax_dx[0] * (ax_dx[0] + ax_dx[1])) ;
1059      b = (      ax_dx[0] + ax_dx[1]) / (ax_dx[0] * ax_dx[1])       ;
1060      c = (     -ax_dx[0]           ) / (ax_dx[1] * (ax_dx[0] + ax_dx[1]));
1061      prepend = a * at::slice(self, direction, 0, 1) + b * at::slice(self, direction, 1, 2) + c * at::slice(self, direction, 2, 3);
1062 
1063      a = (    ax_dx[-1]            ) / (ax_dx[-2] * (ax_dx[-1] + ax_dx[-2]));
1064      b =-(    ax_dx[-1] + ax_dx[-2]) / (ax_dx[-1] * ax_dx[-2]);
1065      c = (2 * ax_dx[-1] + ax_dx[-2]) / (ax_dx[-1] * (ax_dx[-1] + ax_dx[-2]));
1066      append = a * at::slice(self, direction, -3, -2) + b * at::slice(self, direction, -2, -1) + c * at::slice(self, direction, -1);
1067     }
1068 
1069     result.emplace_back(prepend_append_on_dim(center, prepend, append, direction));
1070   }
1071   return result;
1072 }
1073 
1074 static std::vector<Tensor> gradient_helper_float(const Tensor& self, ArrayRef<Scalar> spacing, IntArrayRef dim, int64_t edge_order) {
1075   std::vector<Tensor> result;
1076   for (const auto i : c10::irange(dim.size())) {
1077       int64_t direction = maybe_wrap_dim(dim[i], self.dim());
1078       const auto& ax_dx = spacing[i];
1079       Tensor prepend, append;
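      // `center / 2` (see the final emplace below) is the usual central difference
      //   f'(x_i) ~ (f(x_{i+1}) - f(x_{i-1})) / (2 * ax_dx)
      // while prepend/append use one-sided first- or second-order stencils at the edges.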
1080       auto center  = (at::slice(self,direction, 2   ) - at::slice(self, direction, 0, -2 ) ) / ax_dx;
1081       if (edge_order==1) {
1082         prepend = (at::slice(self,direction, 1, 2) - at::slice(self, direction, 0, 1  ) ) / ax_dx;
1083         append  = (at::slice(self,direction, -1  ) - at::slice(self, direction, -2, -1) ) / ax_dx ;
1084       } else if (edge_order==2) {
1085         prepend = (-1.5 * at::slice(self, direction, 0, 1) + 2 * at::slice(self, direction, 1, 2)   - 0.5 * at::slice(self, direction, 2, 3))/ ax_dx;
1086         append = (0.5 * at::slice(self, direction, -3, -2) - 2 * at::slice(self, direction, -2, -1) + 1.5 * at::slice(self, direction, -1))  / ax_dx;
1087       }
1088 
1089       result.emplace_back(prepend_append_on_dim(center/2, prepend, append, direction));
1090   }
1091   return result;
1092 }
1093 
1094 static std::vector<int64_t> gradient_dim_preprocess(const Tensor& self, std::optional<int64_t> dim) {
1095   // If the gradient dim is provided as an integer, we compute the gradient only along that direction.
1096   // If it is not provided at all, the gradient is computed along all directions.
1097   // Finally, if dim is provided as a vector of ints, this function is not expected to be called.
1098   if (dim.has_value()) {
1099     return std::vector<int64_t>{dim.value()};
1100   }
1101 
1102   std::vector<int64_t> axis(self.dim());
1103   std::iota(axis.begin(), axis.end(), 0);
1104   return axis;
1105 }
1106 
1107 std::vector<Tensor> gradient(const Tensor& self, TensorList coordinates, IntArrayRef dim, int64_t edge_order) {
1108     pre_check_gradient(self,
1109                        std::optional<int64_t>(coordinates.size()),
1110                        at::OptionalIntArrayRef(dim),
1111                        edge_order);
1112     return gradient_helper(self, coordinates, dim, edge_order);
1113 }
1114 
1115 std::vector<Tensor> gradient(const Tensor& self, TensorList coordinates, std::optional<int64_t> dim, int64_t edge_order) {
1116   const auto processed_dim = gradient_dim_preprocess(self, dim);
1117   pre_check_gradient(self,
1118                      std::optional<int64_t>(coordinates.size()),
1119                      dim.has_value() ? at::OptionalIntArrayRef(processed_dim) : std::nullopt,
1120                      edge_order);
1121   return gradient_helper(self, coordinates, processed_dim, edge_order);
1122 }
1123 
1124 std::vector<Tensor> gradient(const Tensor& self, c10::ArrayRef<Scalar> spacing, IntArrayRef dim, int64_t edge_order) {
1125   pre_check_gradient(self,
1126                      std::optional<int64_t>(spacing.size()),
1127                      at::OptionalIntArrayRef(dim),
1128                      edge_order);
1129   return gradient_helper_float(self, spacing, dim, edge_order);
1130 }
1131 
1132 std::vector<Tensor> gradient(const Tensor& self, ArrayRef<Scalar> spacing, std::optional<int64_t> dim, int64_t edge_order) {
1133   const auto processed_dim = gradient_dim_preprocess(self, dim);
1134   pre_check_gradient(self,
1135                      std::optional<int64_t>(spacing.size()),
1136                      dim.has_value() ? at::OptionalIntArrayRef(processed_dim) : std::nullopt,
1137                      edge_order);
1138   return gradient_helper_float(self, spacing, processed_dim, edge_order);
1139 }
1140 
1141 std::vector<Tensor> gradient(const Tensor& self, const Scalar& unit_size, IntArrayRef dim, int64_t edge_order) {
1142   // When spacing is given as a scalar and dim is given as an IntArrayRef, the scalar value needs to
1143   // be taken as the unit size for every element of dim.
1144   std::vector<Scalar> spacing(dim.size(), unit_size);
1145   pre_check_gradient(self,
1146                      std::optional<int64_t>(spacing.size()),
1147                      at::OptionalIntArrayRef(dim),
1148                      edge_order);
1149   return gradient_helper_float(self, spacing, dim, edge_order);
1150 }
1151 
1152 std::vector<Tensor> gradient(const Tensor& self, const std::optional<Scalar>& unit_size, std::optional<int64_t> dim, int64_t edge_order) {
1153   const auto processed_dim = gradient_dim_preprocess(self, dim);
1154   // When unit_size is not provided, it is assumed to be equal to 1.
1155   // When dim has an integer value, the gradient is computed only along that direction; when
1156   // it is not provided, the gradient is computed along all directions.
1157   std::vector<Scalar> spacing(dim.has_value() ? 1 : self.dim(),
1158                               unit_size.has_value() ? unit_size.value() : 1.0) ;
1159   pre_check_gradient(self,
1160                      unit_size.has_value() ?  std::optional<int64_t>(spacing.size()) : std::nullopt,
1161                      dim.has_value() ? at::OptionalIntArrayRef(processed_dim) : std::nullopt,
1162                      edge_order);
1163   return gradient_helper_float(self, spacing, processed_dim, edge_order);
1164 }
1165 
1166 std::vector<Tensor> gradient(const Tensor& self, IntArrayRef dim, int64_t edge_order) {
1167   std::vector<Scalar> spacing(dim.size(), 1.0) ;
1168   pre_check_gradient(self,
1169                      std::optional<int64_t>(spacing.size()),
1170                      at::OptionalIntArrayRef(dim),
1171                      edge_order);
1172   return gradient_helper_float(self, spacing, dim, edge_order);
1173 }
1174 
1175 // ALL REDUCE #################################################################
1176 
1177 inline bool should_use_acc_buffer(at::TensorIterator& iter) {
1178   const auto ndim = iter.ndim();
1179   if (!iter.device().is_cpu() || iter.noutputs() != 1) {
1180     return false;
1181   }
1182   if (!at::isReducedFloatingType(iter.common_dtype())) {
1183     return false;
1184   }
1185   if (ndim < 2) {
1186     return false;
1187   }
1188   auto out_strides = iter.strides(0);
1189   for (const auto dim : c10::irange(0, 2)) {
1190       if (out_strides[dim] != 0) {
1191         return false;
1192       }
1193   }
1194   return true;
1195 }
1196 
1197 TORCH_IMPL_FUNC(sum_out)
1198 (const Tensor& self,
1199  OptionalIntArrayRef opt_dim,
1200  bool keepdim,
1201  std::optional<ScalarType> opt_dtype,
1202  const Tensor& result) {
1203   auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type());
1204   if (iter.numel() == 0) {
1205     result.zero_();
1206   } else {
1207     // Here is a limitation of TensorIterator reductions for permuted input with lower precision on CPU.
1208     // Consider the case: TensorIterator coalesces such input and output to >= 2 dims tensors,
1209     // and the output stride is [0, 0, x, x, ...] with x >= 0 (two reduced dimensions and non-reduced dims).
1210     // Since the reduction loop only operates on two dimensions at a time,
1211     // the intermediate sums are forced to accumulate in the second reduced dim with lower precision.
1212     // See https://github.com/pytorch/pytorch/issues/83149
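    // An illustrative scenario (not a documented repro): summing a permuted
    // bfloat16 CPU tensor over several dims can coalesce to such an output
    // stride pattern; accumulating into a float buffer first avoids the
    // precision loss.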
1213     if (should_use_acc_buffer(iter)) {
1214       auto tmp_output = at::empty(result.sizes(), result.options().dtype(kFloat));
1215       at::sum_outf(self.to(ScalarType::Float), opt_dim, keepdim, /*dtype=*/std::nullopt, tmp_output);
1216       result.copy_(tmp_output);
1217     } else{
1218       sum_stub(iter.device_type(), iter);
1219     }
1220   }
1221 }
1222 
1223 Tensor sum(const Tensor &self, std::optional<ScalarType> dtype) {
1224   return at::sum(self, IntArrayRef{}, false, dtype);
1225 }
1226 
1227 Tensor sum(const Tensor& self, DimnameList dim, bool keepdim, std::optional<ScalarType> dtype) {
1228   return at::sum(self, dimnames_to_positions(self, dim), keepdim, dtype);
1229 }
1230 
1231 Tensor& sum_out(const Tensor& self, DimnameList dim,
1232                 bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1233   return at::sum_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype);
1234 }
1235 
1236 Tensor& nansum_out(const Tensor& self, at::OptionalIntArrayRef dim,
1237                        bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1238   if (self.device().is_cpu()) {
1239     TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum does not support complex inputs");
1240   }
1241 
1242   // For integral types, use the existing sum implementation, as
1243   // integral types cannot represent `NaN`.
1244   if (c10::isIntegralType(self.scalar_type(), true)){
1245     return at::sum_out(result, self, dim, keepdim, opt_dtype);
1246   }
1247 
1248   ScalarType dtype = get_dtype_from_result(result, opt_dtype);
1249   auto iter = make_reduction("nansum", result, self, dim, keepdim, dtype);
1250   if (iter.numel() == 0) {
1251     result = result.zero_();
1252   } else {
1253     nansum_stub(iter.device_type(), iter);
1254   }
1255   return result;
1256 }
1257 
1258 Tensor nansum(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
1259   ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
1260   Tensor result = create_reduction_result(self, dim, keepdim, dtype);
1261   return at::native::nansum_out(self, dim, keepdim, dtype, result);
1262 }
1263 
1264 namespace {
1265 template<typename scalar_t, typename accscalar_t = at::acc_type<scalar_t, false>>
1266 void inline set_result(Tensor& result, accscalar_t sum)
1267 {
1268     if constexpr (std::is_integral_v<accscalar_t>) {
1269       // all integer types get promoted to kLong
1270       *result.data_ptr<int64_t>() = sum;
1271     } else {
1272       *result.data_ptr<scalar_t>() = sum;
1273     }
1274 }
1275 }
1276 // NOTE: this could be implemented via diag and sum, but this has perf problems,
1277 // see https://github.com/pytorch/pytorch/pull/47305.
1278 Tensor trace_cpu(const Tensor& self) {
1279   Tensor result;
1280   // Returns the ScalarType of the self tensor if the tensor is of a non-integral type.
1281   // If self is an integral type tensor, at::kLong is returned, since promote_integers
1282   // is set to true.
1283   ScalarType dtype = get_dtype_from_self(self, std::nullopt, true);
1284   result = at::empty({}, self.options().dtype(dtype));
1285   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] {
1286     using accscalar_t = at::acc_type<scalar_t, false>;
1287     accscalar_t sum = 0;
1288     const auto* t_data = self.const_data_ptr<scalar_t>();
1289 
1290     int64_t t_stride_0, t_stride_1, t_diag_size;
1291 
1292     TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());
1293 
1294     t_stride_0 = self.stride(0);
1295     t_stride_1 = self.stride(1);
1296 
1297     t_diag_size = std::min(self.size(0), self.size(1));
1298     for (const auto i : c10::irange(t_diag_size)) {
1299       sum += t_data[i * (t_stride_0 + t_stride_1)];
1300     }
1301     set_result<scalar_t>(result, sum);
1302 
1303   });
1304 
1305   return result;
1306 }
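// Illustrative note: the loop above walks the main diagonal directly through
// strides (element i lives at offset i * (stride(0) + stride(1))), so no
// intermediate diagonal tensor is materialized. For a 2x2 matrix
// [[1, 2], [3, 4]], trace() returns 1 + 4 == 5, promoted to kLong for
// integral inputs.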
1307 
1308 static void impl_func_prod(
1309     const Tensor& self,
1310     IntArrayRef dims,
1311     bool keepdim,
1312     std::optional<ScalarType> dtype,
1313     const Tensor& result) {
1314   auto iter = meta::make_reduction_from_out_ty(self, result, dims, keepdim, result.scalar_type());
1315   if (iter.numel() == 0) {
1316     result.fill_(1);
1317   } else {
1318     prod_stub(iter.device_type(), iter);
1319   }
1320 }
1321 
1322 TORCH_IMPL_FUNC(prod_out)
1323 (const Tensor& self,
1324  int64_t dim,
1325  bool keepdim,
1326  std::optional<ScalarType> dtype,
1327  const Tensor& result) {
1328   impl_func_prod(self, dim, keepdim, dtype, result);
1329 }
1330 
1331 Tensor prod(const Tensor &self, std::optional<ScalarType> opt_dtype) {
1332   auto dtype = get_dtype_from_self(self, opt_dtype, true);
1333   auto shape = meta::get_reduction_shape(self, {}, false);
1334   Tensor result = at::empty(shape, self.options().dtype(dtype));
1335   impl_func_prod(self, {}, false, dtype, result);
1336   return result;
1337 }
1338 
1339 Tensor prod(const Tensor& self, Dimname dim, bool keepdim, std::optional<ScalarType> dtype) {
1340   return at::prod(self, dimname_to_position(self, dim), keepdim, dtype);
1341 }
1342 
1343 Tensor& prod_out(const Tensor& self, Dimname dim,
1344                  bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1345   return at::prod_out(result, self, dimname_to_position(self, dim), keepdim, opt_dtype);
1346 }
1347 
1348 TORCH_IMPL_FUNC(mean_out)
1349 (const Tensor& self,
1350  OptionalIntArrayRef opt_dim,
1351  bool keepdim,
1352  std::optional<ScalarType> opt_dtype,
1353  const Tensor& result) {
1354   ScalarType dtype = result.scalar_type();
1355   // TODO: the TensorIterator reduction implementation of mean
1356   // (mean_kernel_impl()) is unvectorized and leads to very poor performance
1357   // for production workloads. Once that's fixed, the following code can be used
1358   // in lieu of the sum + divide implementation below.
1359   if (self.device().is_cpu()) {
1360     int64_t dim_prod = 1;
1361     if (!opt_dim.has_value() || opt_dim.value().empty() || self.ndimension() == 0) {
1362       dim_prod = self.numel();
1363     } else {
1364       auto dim = opt_dim.value();
1365       for (auto d : dim) {
1366         dim_prod *= self.size(d);
1367       }
1368     }
1369     auto& result_mut = const_cast<Tensor&>(result);
1370     // For accuracy reasons, BF16/FP16 mean should be computed via the
1371     // following approach:
1372     //  cast_fp32 -> sum -> div -> cast_bf16_or_fp16
1373     //
1374     // Such an approach is necessary because if we were to use the same
1375     // approach for BF16/FP16 as for FP32 here, it would result in
1376     // the following code-flow -
1377     // cast_fp32 -> sum -> cast_bf16 -> cast_fp32 -> div -> cast_bf16,
1378     // which, in turn, produces less accurate results.
1379     bool is_half_type = (dtype == kHalf || dtype == kBFloat16);
1380     auto sum_out_dtype = is_half_type ? ScalarType::Float : dtype;
1381     result_mut = is_half_type ? result_mut.to(sum_out_dtype) : result_mut;
1382     // If dtype is FP16 or BF16, self (input tensor) will initially be cast to
1383     // FP32 in sum_out. This results in having to read that FP32 tensor again,
1384     // but maybe in the future, we could revise the implementation to not
1385     // materialize that intermediate FP32 tensor. That approach would probably
1386     // require some modifications in binary_kernel_reduce_vec(),
1387     // TensorIteratorBase::for_each(), and
1388     // TensorIteratorBase::serial_for_each(), apart from sum kernel for CPU.
1389     at::sum_out(result_mut, self, opt_dim, keepdim, sum_out_dtype).div_(dim_prod);
1390     // After sum & div, cast result_mut back to BF16 or FP16, if required.
1391     result_mut = is_half_type ? result_mut.to(dtype) : result_mut;
1392   } else {
1393     // device is not CPU
1394     auto iter = at::meta::make_reduction_from_out_ty(
1395         self, result, opt_dim, keepdim, dtype);
1396     if (iter.numel() == 0) {
1397       result.fill_(std::numeric_limits<double>::quiet_NaN());
1398     } else {
1399       mean_stub(iter.device_type(), iter);
1400     }
1401   }
1402 }
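// Illustrative note: on CPU the mean above is computed as sum / dim_prod,
// where dim_prod is the product of the reduced dimension sizes (or numel()
// for a full reduction). For kHalf/kBFloat16 the sum is accumulated in kFloat
// and only cast back at the end, e.g. a {2, 3} bf16 tensor reduced over dim 1
// divides an fp32 row sum by 3 before casting back to bf16.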
1403 
1404 Tensor mean(const Tensor &self, std::optional<ScalarType> dtype) {
1405   return at::mean(self, IntArrayRef{}, false, dtype);
1406 }
1407 
1408 Tensor mean(const Tensor& self, DimnameList dim, bool keepdim, std::optional<ScalarType> dtype) {
1409   return at::mean(self, dimnames_to_positions(self, dim), keepdim, dtype);
1410 }
1411 
1412 Tensor& mean_out(const Tensor& self, DimnameList dim,
1413                  bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1414   return at::mean_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype);
1415 }
1416 
1417 Tensor& mean_dtype_out(const Tensor &self, std::optional<ScalarType> dtype, Tensor& result) {
1418   TORCH_CHECK(
1419     canCast(self.scalar_type(), result.scalar_type()),
1420       "mean.dtype_out(): input types can't be cast to the desired output type ",
1421       result.scalar_type());
1422   // at::mean_out should make sure dtype and result.scalar_type() are the same
1423   return at::mean_out(result, self, IntArrayRef{}, false, dtype);
1424 }
1425 
1426 // TODO(@heitorschueroff) implement custom kernels for nanmean
1427 Tensor& nanmean_out(
1428     const Tensor& self,
1429     at::OptionalIntArrayRef dim,
1430     bool keepdim,
1431     std::optional<ScalarType> opt_dtype,
1432     Tensor& result) {
1433   TORCH_CHECK(
1434       self.is_floating_point() || self.is_complex(),
1435       "nanmean(): expected input to have floating point or complex dtype but got ",
1436       self.scalar_type());
1437   const auto factor = at::native::isnan(self).logical_not_().sum(dim, keepdim);
1438   at::native::nansum_out(self, dim, keepdim, opt_dtype, result).div_(factor);
1439   return result;
1440 }
1441 
1442 Tensor nanmean(
1443     const Tensor& self,
1444     at::OptionalIntArrayRef dim,
1445     bool keepdim,
1446     std::optional<ScalarType> opt_dtype) {
1447   TORCH_CHECK(
1448       self.is_floating_point() || self.is_complex(),
1449       "nanmean(): expected input to have floating point or complex dtype but got ",
1450       self.scalar_type());
1451   const auto factor =
1452       at::native::isnan(self.detach()).logical_not_().sum(dim, keepdim);
1453   return at::nansum(self, dim, keepdim, opt_dtype).div(factor);
1454 }
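// Illustrative note: nanmean is nansum divided by the per-slice count of
// non-NaN elements (the isnan().logical_not_().sum(...) factor above), so for
// the values {1.0, NaN, 3.0} the result is (1.0 + 3.0) / 2 == 2.0. A slice
// that is all NaN divides 0 by 0 and therefore yields NaN.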
1455 
1456 static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
1457   // can't take max of empty tensor
1458   if (self.numel() != 0) {
1459     // For complex numbers, use the real part to calculate the max. Based on
1460     // https://scicomp.stackexchange.com/questions/34273/log-sum-exp-trick-for-signed-complex-numbers
1461     auto maxes = at::amax(at::real(self), dims, true);
1462     auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims));
1463     maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
1464     at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
1465     result.log_().add_(maxes_squeezed);
1466   } else {
1467     at::sum_out(result, at::exp(self), dims, keepdim);
1468     result.log_();
1469   }
1470   return result;
1471 }
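// Illustrative note: the helper above relies on the standard max-shift
// identity
//   logsumexp(x) = m + log(sum(exp(x - m))),  with m = max(real(x)) per slice,
// which keeps exp() from overflowing. Infinite maxes are masked to zero first
// so that (self - maxes) cannot produce inf - inf == NaN.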
1472 
1473 Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
1474   // Complex type implies floating point type
1475   TORCH_CHECK(at::isFloatingType(result.scalar_type()) || at::isComplexType(result.scalar_type()),
1476               "logsumexp(): Expected floating point type for result tensor, but got: ",
1477               result.scalar_type());
1478   {
1479     NoNamesGuard guard;
1480     if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
1481       // for integral inputs, promote input to default floating type.
1482       auto default_dtype = at::typeMetaToScalarType(c10::get_default_dtype());
1483       logsumexp_out_impl(result, self.to(default_dtype), dims, keepdim);
1484     } else {
1485       logsumexp_out_impl(result, self, dims, keepdim);
1486     }
1487   }
1488   namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
1489   return result;
1490 }
1491 
1492 Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
1493   TensorOptions result_options;
1494   if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
1495     // even for integral inputs, result is floating dtype
1496     auto default_dtype = at::typeMetaToScalarType(c10::get_default_dtype());
1497     result_options = self.options().dtype(default_dtype);
1498   } else {
1499     result_options = self.options();
1500   }
1501   auto result = at::empty({0}, result_options);
1502   return at::logsumexp_outf(self, dims, keepdim, result);
1503 }
1504 
1505 Tensor logsumexp(const Tensor& self, DimnameList dims, bool keepdim) {
1506   return at::logsumexp(self, dimnames_to_positions(self, dims), keepdim);
1507 }
1508 
1509 Tensor& logsumexp_out(const Tensor& self, DimnameList dims, bool keepdim, Tensor& result) {
1510   return at::logsumexp_out(result, self, dimnames_to_positions(self, dims), keepdim);
1511 }
1512 
1513 // special_logsumexp, alias for logsumexp
1514 Tensor special_logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
1515   return self.logsumexp(dims, keepdim);
1516 }
1517 Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
1518   return at::logsumexp_out(result, self, dims, keepdim);
1519 }
1520 
1521 static void impl_func_norm(
1522     const Tensor& self,
1523     const OptionalScalarRef& opt_p,
1524     IntArrayRef dim,
1525     bool keepdim,
1526     std::optional<ScalarType> opt_dtype,
1527     const Tensor& result) {
1528   // This implementation is kept (rather than deprecated) because it is called in a number
1529   // of places in the codebase. Those call sites should eventually switch to linalg_vector_norm.
1530   auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
1531   at::linalg_vector_norm_out(const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
1532 }
1533 
1534 TORCH_IMPL_FUNC(norm_out)
1535 (const Tensor& self,
1536  const OptionalScalarRef p,
1537  IntArrayRef dim,
1538  bool keepdim,
1539  const Tensor& result) {
1540   impl_func_norm(self, p, dim, keepdim, std::nullopt, result);
1541 }
1542 
1543 TORCH_IMPL_FUNC(norm_dtype_out)
1544 (const Tensor& self,
1545  const OptionalScalarRef p,
1546  IntArrayRef dim,
1547  bool keepdim,
1548  ScalarType dtype,
1549  const Tensor& result) {
1550   impl_func_norm(self, p, dim, keepdim, dtype, result);
1551 }
1552 
1553 Tensor sparse_norm(
1554     const Tensor& self,
1555     const std::optional<Scalar>& p,
1556     IntArrayRef dim,
1557     bool keepdim) {
1558   return at::native_norm(self, p, dim, keepdim, std::nullopt);
1559 }
1560 
1561 Tensor sparse_dtype_norm(
1562     const Tensor& self,
1563     const std::optional<Scalar>& p,
1564     IntArrayRef dim,
1565     bool keepdim,
1566     ScalarType dtype) {
1567   return at::native_norm(self, p, dim, keepdim, dtype);
1568 }
1569 
1570 Tensor norm(const Tensor& self, const std::optional<Scalar>& p, ScalarType dtype) {
1571   return at::norm(self, p, IntArrayRef{}, false, dtype);
1572 }
1573 
1574 Tensor norm(const Tensor& self, const Scalar& p) {
1575   return at::norm(self, p, IntArrayRef{}, false);
1576 }
1577 
1578 inline TensorIterator get_allany_iter(
1579     const Tensor& self,
1580     const Tensor& result,
1581     OptionalIntArrayRef dims,
1582     bool keepdim) {
1583   if (self.is_cuda()) {
1584     // As CUDA supports dynamic type casting, we use this overload of
1585     // `make_reduction`, which doesn't cast the input to the result type (i.e. kBool);
1586     // otherwise we use the overload below, which casts the input to kBool (an
1587     // extra operation).
1588     return meta::make_reduction(self, result, dims, keepdim, self.scalar_type());
1589   }
1590   return meta::make_reduction_from_out_ty(
1591       self, result, dims, keepdim, result.scalar_type());
1592 }
1593 
1594 template <int identity, typename Stub>
1595 inline void allany_impl(
1596     const Tensor& self,
1597     const Tensor& result,
1598     OptionalIntArrayRef dims,
1599     bool keepdim,
1600     Stub& stub) {
1601   if (self.numel() == 0) {
1602     result.fill_(identity);
1603   } else if (self.numel() == 1) {
1604     result.copy_(self.view_as(result).to(at::kBool));
1605   } else {
1606     auto iter = get_allany_iter(self, result, dims, keepdim);
1607     stub(iter.device_type(), iter);
1608   }
1609 }
1610 
1611 TORCH_IMPL_FUNC(all_out)
1612 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
1613   allany_impl<1>(self, result, dim, keepdim, and_stub);
1614 }
1615 
1616 TORCH_IMPL_FUNC(all_dims_out)
1617 (const Tensor& self, OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
1618   allany_impl<1>(self, result, dim, keepdim, and_stub);
1619 }
1620 
1621 TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
1622   allany_impl<1>(self, result, {}, false, and_stub);
1623 }
1624 
1625 TORCH_IMPL_FUNC(any_out)
1626 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
1627   allany_impl<0>(self, result, dim, keepdim, or_stub);
1628 }
1629 
1630 TORCH_IMPL_FUNC(any_dims_out)
1631 (const Tensor& self, OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
1632   allany_impl<0>(self, result, dim, keepdim, or_stub);
1633 }
1634 
1635 TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
1636   allany_impl<0>(self, result, {}, false, or_stub);
1637 }
1638 
1639 template <bool is_all>
1640 Tensor allany_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
1641   // Default implementation in terms of all-reduce or single dim reduce
1642   if (!dim) {
1643     Tensor out;
1644     if constexpr (is_all) {
1645       out = self.all();
1646     } else {
1647       out = self.any();
1648     }
1649 
1650     if (keepdim) {
1651       DimVector out_shape(self.dim(), 1);
1652       return out.expand(out_shape);
1653     }
1654     return out;
1655   }
1656 
1657   if (dim->empty()) {
1658     if (self.scalar_type() == kByte) {
1659       // Convert to a 1 or 0 mask
1660       auto out = at::empty_like(self);
1661       return at::ne_outf(self, 0, out);
1662     } else {
1663       return at::_to_copy(self, kBool);
1664     }
1665   }
1666 
1667   Tensor out = self;
1668   for (auto d : *dim) {
1669     if constexpr (is_all) {
1670       out = out.all(d, /*keepdim=*/true);
1671     } else {
1672       out = out.any(d, /*keepdim=*/true);
1673     }
1674   }
1675   return keepdim ? out : out.squeeze(*dim);
1676 }
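// Illustrative note: the default above reduces one dimension at a time with
// keepdim=true and squeezes at the end, which is valid because all()/any()
// commute across dims (e.g. all over {0, 2} equals all over dim 2 of the all
// over dim 0). An empty dim list performs no reduction and only converts the
// values to a boolean (or 0/1 byte) mask.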
1677 
1678 Tensor all_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
1679   return allany_dims_default<true>(self, dim, keepdim);
1680 }
1681 
1682 Tensor any_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
1683   return allany_dims_default<false>(self, dim, keepdim);
1684 }
1685 
1686 Tensor& all_dims_out_default(
1687     const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
1688   TORCH_CHECK(self.device() == result.device(), "all.dims: output must be on the same device as input");
1689   auto tmp = all_dims_default(self, dim, keepdim);
1690   at::native::resize_output(result, tmp.sizes());
1691   return result.copy_(tmp);
1692 }
1693 
1694 Tensor& any_dims_out_default(
1695     const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
1696   TORCH_CHECK(self.device() == result.device(), "any.dims: output must be on the same device as input");
1697   auto tmp = any_dims_default(self, dim, keepdim);
1698   at::native::resize_output(result, tmp.sizes());
1699   return result.copy_(tmp);
1700 }
1701 
1702 TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
1703   auto iter =
1704       meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
1705   if (iter.numel() != 0) {
1706     min_values_stub(iter.device_type(), iter);
1707   }
1708 }
1709 
1710 TORCH_IMPL_FUNC(amax_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
1711   auto iter =
1712       meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
1713   if (iter.numel() != 0) {
1714     max_values_stub(iter.device_type(), iter);
1715   }
1716 }
1717 
1718 template <class Stub>
1719 void argmax_argmin_impl(
1720     const Tensor& self,
1721     std::optional<int64_t> dim,
1722     bool keepdim,
1723     const Tensor& result,
1724     Stub& stub) {
1725   c10::MaybeOwned<Tensor> in;
1726   DimVector dims;
1727   int64_t _dim = 0;
1728 
1729   if (dim.has_value()) {
1730     _dim = maybe_wrap_dim(dim.value(), self.dim());
1731     auto sizes = self.sizes();
1732 
1733     if (sizes[_dim] == 1) {
1734       result.fill_(0);
1735       return;
1736     }
1737 
1738     dims = IntArrayRef(_dim);
1739     in = c10::MaybeOwned<Tensor>::borrowed(self);
1740   } else {
1741     in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
1742     keepdim = false;
1743   }
1744 
1745   auto iter =
1746       meta::make_reduction(*in, result, dims, keepdim, self.scalar_type());
1747 
1748   if (iter.numel() != 0) {
1749     stub(iter.device_type(), iter);
1750   }
1751 }
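// Illustrative note: when dim is nullopt the input is flattened above, so the
// returned index refers to the flattened tensor; when the reduced dimension
// has size 1 the answer is trivially 0 and the reduction stub is skipped.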
1752 
1753 TORCH_IMPL_FUNC(argmax_out)
1754 (const Tensor& self,
1755  std::optional<int64_t> dim,
1756  bool keepdim,
1757  const Tensor& result) {
1758   argmax_argmin_impl(self, dim, keepdim, result, argmax_stub);
1759 }
1760 
1761 TORCH_IMPL_FUNC(argmin_out)
1762 (const Tensor& self,
1763  std::optional<int64_t> dim,
1764  bool keepdim,
1765  const Tensor& result) {
1766   argmax_argmin_impl(self, dim, keepdim, result, argmin_stub);
1767 }
1768 
1769 static double std_var_all_cpu(const Tensor& self, double correction, bool take_sqrt) {
1770   const auto dtype = self.scalar_type();
1771   TORCH_CHECK(dtype == kDouble || dtype == kFloat,
1772               "std_var_all: Unsupported dtype ", dtype);
1773 
1774   auto mean = self.mean().item<double>();
1775   auto iter = TensorIteratorConfig()
1776       .add_const_input(self)
1777       .build();
1778 
1779   auto reduction = [&](int64_t begin, int64_t end, double thread_sum) {
1780     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "std_var_all_cpu", [&] {
1781       iter.serial_for_each([&] (char** data, const int64_t* strides, int64_t size0, int64_t size1) {
1782         const double local_mean = mean;
1783         const int64_t inner_stride = strides[0];
1784         const int64_t outer_stride = strides[1];
1785 
1786         double local_sum = 0.0;
1787         for (const auto i : c10::irange(size1)) {
1788           const char* row_ptr = data[0] + outer_stride * i;
1789           for (const auto j : c10::irange(size0)) {
1790             const auto ptr = reinterpret_cast<const scalar_t*>(row_ptr + inner_stride * j);
1791             auto dx = (static_cast<double>(*ptr) - local_mean);
1792             local_sum += dx * dx;
1793           }
1794         }
1795         thread_sum += local_sum;
1796       }, {begin, end});
1797     });
1798 
1799     return thread_sum;
1800   };
1801 
1802   // ((x - mean)**2).sum()
1803   const double sum_dx2 = at::parallel_reduce(
1804       0, iter.numel(), at::internal::GRAIN_SIZE, 0.0, reduction, std::plus<>{});
1805 
1806   const auto var = [&] () __ubsan_ignore_float_divide_by_zero__ {
1807     return sum_dx2 / std::max(0.0, self.numel() - correction);
1808   }();
1809   const auto result = take_sqrt ? std::sqrt(var) : var;
1810 
1811   if (dtype == kFloat) {
1812     // Convert to infinity if out of range for a float.
1813     // Doing it now prevents checked_convert failing later
1814     return static_cast<float>(result);
1815   }
1816   return result;
1817 }
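// Illustrative note: the all-reduce above is the plain two-pass formula
//   var = sum((x - mean)^2) / max(0.0, N - correction)
// with the squared-deviation sum accumulated in double via parallel_reduce
// over disjoint chunks; take_sqrt then turns the variance into the standard
// deviation.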
1818 
1819 namespace {
1820   inline void warn_invalid_degrees_of_freedom(const char* fname, const TensorIterator& iter, double correction) {
1821     int64_t reducing_over_num_elements = iter.num_output_elements() == 0 ? 0 : iter.numel() / iter.num_output_elements();
1822     if (reducing_over_num_elements - correction <= 0) {
1823       TORCH_WARN(fname, "(): degrees of freedom is <= 0. Correction should be strictly less than the reduction factor (input numel divided by output numel).");
1824     }
1825   }
1826 } // namespace
1827 
1828 static Tensor& std_var_out(
1829     const char* fname, Tensor& result, const Tensor& self,
1830     at::OptionalIntArrayRef dim, const std::optional<Scalar>& correction_opt,
1831     bool keepdim, bool take_sqrt) {
1832   TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda() || self.device().is_xpu(),
1833               "std and var supports tensors on a CPU, CUDA, or XPU device only, but got: ",
1834               self.device().type());
1835   TORCH_CHECK(self.layout() == Layout::Strided,
1836               "std and var only supports strided layout, got: ", self.layout());
1837   TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
1838               "std and var only support floating point and complex dtypes");
1839 
1840   if (at::isComplexType(self.scalar_type())) {
1841     // For complex, calculate variance of real and imaginary components
1842     // separately then add to get overall variance.
1843     ScalarType dtype = c10::toRealValueType(get_dtype_from_result(result, {}));
1844     Tensor real_in = at::real(self);
1845     Tensor real_out = at::empty({0}, self.options().dtype(dtype));
1846     std_var_out(
1847         fname,
1848         real_out,
1849         real_in,
1850         dim,
1851         correction_opt,
1852         keepdim,
1853         /*take_sqrt=*/false);
1854 
1855     Tensor imag_in = at::imag(self);
1856     Tensor imag_out = at::empty({0}, self.options().dtype(dtype));
1857     std_var_out(
1858         fname,
1859         imag_out,
1860         imag_in,
1861         dim,
1862         correction_opt,
1863         keepdim,
1864         /*take_sqrt=*/false);
1865 
1866     at::add_out(result, real_out, imag_out);
1867     if (take_sqrt) {
1868       at::sqrt_out(result, result);
1869     }
1870     return result;
1871   }
1872 
1873   // Computation for floating point
1874   const auto correction = correction_opt.value_or(1).toDouble();
1875   ScalarType dtype = get_dtype_from_result(result, {});
1876   auto iter = make_reduction(fname, result, self, dim, keepdim, dtype);
1877   TORCH_CHECK(at::canCast(self.scalar_type(), result.scalar_type()),
1878               "result type ", self.scalar_type(), " can't be cast to the "
1879               "desired output type ", result.scalar_type());
1880   warn_invalid_degrees_of_freedom(fname, iter, correction);
1881 
1882   if (iter.numel() == 0) {
1883     // Trivial reduction
1884     result.fill_(std::numeric_limits<double>::quiet_NaN());
1885     return result;
1886   } else if (
1887       result.numel() == 1 && iter.device_type() == kCPU &&
1888       iter.common_dtype() != kBFloat16 && iter.common_dtype() != kHalf) {
1889     // NOTE: CPU performance significantly regressed when attempting to port to
1890     // ATen,
1891     //   so all-reduce has a custom implementation.
1892     //   See https://github.com/pytorch/pytorch/pull/43858.
1893     result.fill_(std_var_all_cpu(self, correction, take_sqrt));
1894   } else {
1895     std_var_stub(iter.device_type(), iter, correction, take_sqrt);
1896   }
1897   return result;
1898 }
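// Illustrative note: for complex inputs the helper above uses
//   var(z) = var(real(z)) + var(imag(z))
// computed on real-valued temporaries of matching precision, and only takes
// the square root (for std) after the two parts have been added.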
1899 
1900 static std::tuple<Tensor&, Tensor&> std_var_mean_out(
1901     const char* fname, Tensor& result1, Tensor& result2, const Tensor& self,
1902     at::OptionalIntArrayRef dim, const std::optional<Scalar>& correction_opt,
1903     bool keepdim, bool take_sqrt) {
1904   AT_ASSERT(result1.defined() && result2.defined());
1905   TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_xpu(),
1906               fname, " supports tensors on a CPU, CUDA, or XPU device only, got: ",
1907               self.device().type());
1908   TORCH_CHECK(self.layout() == Layout::Strided,
1909               fname, " only supports strided layout, got: ", self.layout());
1910   TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
1911               fname, " only support floating point and complex dtypes");
1912   TORCH_CHECK(result1.scalar_type() == c10::toRealValueType(result2.scalar_type()),
1913               fname, " expected result1 to be real and match the precision of result2. Got ",
1914               result1.scalar_type(), " and ", result2.scalar_type(), ".");
1915 
1916   if (at::isComplexType(self.scalar_type())) {
1917     // For complex, calculate for real and imaginary components separately then combine as:
1918     // variance = var_real + var_imag
1919     // mean = mean_real + j * mean_imag
1920     ScalarType dtype = c10::toRealValueType(get_dtype_from_result(result1, {}));
1921     Tensor real_in = at::real(self);
1922     Tensor real_out_var = at::empty({0}, self.options().dtype(dtype));
1923     Tensor real_out_mean = at::empty({0}, self.options().dtype(dtype));
1924     std_var_mean_out(
1925         fname,
1926         real_out_var,
1927         real_out_mean,
1928         real_in,
1929         dim,
1930         correction_opt,
1931         keepdim,
1932         /*take_sqrt=*/false);
1933 
1934     Tensor imag_in = at::imag(self);
1935     Tensor imag_out_var = at::empty({0}, self.options().dtype(dtype));
1936     Tensor imag_out_mean = at::empty({0}, self.options().dtype(dtype));
1937     std_var_mean_out(
1938         fname,
1939         imag_out_var,
1940         imag_out_mean,
1941         imag_in,
1942         dim,
1943         correction_opt,
1944         keepdim,
1945         /*take_sqrt=*/false);
1946 
1947     at::add_out(result1, real_out_var, imag_out_var);
1948     if (take_sqrt) {
1949       at::sqrt_out(result1, result1);
1950     }
1951     at::complex_out(result2, real_out_mean, imag_out_mean);
1952     return std::tuple<Tensor&, Tensor&>(result1, result2);
1953   }
1954 
1955   // Computation for floating point
1956   const auto correction = correction_opt.value_or(1).toDouble();
1957   ScalarType dtype = get_dtype_from_result(result1, {});
1958   auto iter =
1959       make_reduction(fname, result1, result2, self, dim, keepdim, dtype);
1960   warn_invalid_degrees_of_freedom(fname, iter, correction);
1961 
1962   if (iter.numel() == 0) {
1963     // Trivial reduction
1964     result1.fill_(std::numeric_limits<double>::quiet_NaN());
1965     result2.fill_(std::numeric_limits<double>::quiet_NaN());
1966   } else {
1967     std_var_stub(iter.device_type(), iter, correction, take_sqrt);
1968   }
1969   return std::tuple<Tensor&, Tensor&>(result1, result2);
1970 }
1971 
1972 std::tuple<Tensor, Tensor> var_mean(
1973     const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1974   return at::var_mean(
1975       self, /*dim=*/at::OptionalIntArrayRef(dim),
1976       /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
1977       keepdim);
1978 }
1979 
1980 std::tuple<Tensor, Tensor> std_mean(
1981     const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1982   return at::std_mean(
1983       self, /*dim=*/at::OptionalIntArrayRef(dim),
1984       /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
1985       keepdim);
1986 }
1987 
1988 std::tuple<Tensor, Tensor> std_mean(const Tensor& self, bool unbiased) {
1989   return at::std_mean(
1990       self, /*dim=*/std::nullopt,
1991       /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
1992 }
1993 
1994 std::tuple<Tensor, Tensor> var_mean(const Tensor& self, bool unbiased) {
1995   return at::var_mean(
1996       self, /*dim=*/std::nullopt,
1997       /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
1998 }
1999 std::tuple<Tensor&, Tensor&> var_mean_out(
2000     Tensor& result1, Tensor& result2, const Tensor& self, IntArrayRef dim,
2001     int64_t correction, bool keepdim) {
2002   return std_var_mean_out(
2003       "var_mean", result1, result2, self, dim, correction, keepdim, false);
2004 }
2005 
2006 static TensorOptions options_to_value_type(TensorOptions opts) {
2007   auto scalar_type = typeMetaToScalarType(opts.dtype());
2008   return opts.dtype(c10::toRealValueType(scalar_type));
2009 }
2010 
2011 std::tuple<Tensor, Tensor> var_mean(
2012     const Tensor& self, at::OptionalIntArrayRef dim,
2013     const std::optional<Scalar>& correction, bool keepdim) {
2014   Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
2015   Tensor result2 = at::empty({0}, self.options());
2016   return std_var_mean_out(
2017       "var_mean", result1, result2, self, dim, correction, keepdim, false);
2018 }
2019 
2020 std::tuple<Tensor, Tensor> std_mean(
2021     const Tensor& self, at::OptionalIntArrayRef dim,
2022     const std::optional<Scalar>& correction, bool keepdim) {
2023   Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
2024   Tensor result2 = at::empty({0}, self.options());
2025   return std_var_mean_out(
2026       "std_mean", result1, result2, self, dim, correction, keepdim, true);
2027 }
2028 
2029 Tensor var(const Tensor& self, bool unbiased) {
2030   return at::var(
2031       self, /*dim=*/std::nullopt,
2032       /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
2033 }
2034 
2035 Tensor var(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
2036   return at::var(
2037       self, /*dim=*/at::OptionalIntArrayRef(dim),
2038       /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
2039       keepdim);
2040 }
2041 
2042 Tensor& var_out(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
2043   return at::var_out(
2044       result, self, /*dim=*/at::OptionalIntArrayRef(dim),
2045       /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
2046       keepdim);
2047 }
2048 
2049 Tensor std(const Tensor& self, bool unbiased) {
2050   return at::std(
2051       self, /*dim=*/std::nullopt, /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
2052 }
2053 
2054 Tensor std(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
2055   return at::std(self, dim,
2056                  /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
2057 }
2058 
2059 Tensor& std_out(const Tensor& self, at::OptionalIntArrayRef opt_dim, bool unbiased, bool keepdim, Tensor& result) {
2060   return at::std_out(result, self, opt_dim,
2061                      /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
2062 }
2063 
2064 Tensor std(const Tensor& self, at::OptionalIntArrayRef dim,
2065            const std::optional<Scalar>& correction, bool keepdim) {
2066   Tensor result = at::empty({0}, options_to_value_type(self.options()));
2067   return std_var_out("std", result, self, dim, correction, keepdim, true);
2068 }
2069 
2070 Tensor& std_out(
2071     const Tensor& self, at::OptionalIntArrayRef dim,
2072     const std::optional<Scalar>& correction, bool keepdim, Tensor& result) {
2073   return std_var_out("std", result, self, dim, correction, keepdim, true);
2074 }
2075 
2076 Tensor& var_out(
2077     const Tensor& self, at::OptionalIntArrayRef dim,
2078     const std::optional<Scalar>& correction, bool keepdim, Tensor& result) {
2079   return std_var_out("var", result, self, dim, correction, keepdim, false);
2080 }
2081 
2082 Tensor var(
2083     const Tensor& self, at::OptionalIntArrayRef dim,
2084     const std::optional<Scalar>& correction, bool keepdim) {
2085   Tensor result = at::empty({0}, options_to_value_type(self.options()));
2086   return std_var_out("var", result, self, dim, correction, keepdim, false);
2087 }
2088 
2089 Tensor std(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2090   return at::std(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2091 }
2092 
2093 Tensor& std_out(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim, Tensor& result) {
2094   return at::std_out(result, self, dimnames_to_positions(self, dim), unbiased, keepdim);
2095 }
2096 
2097 Tensor var(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2098   return at::var(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2099 }
2100 
2101 Tensor& var_out(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim, Tensor& result) {
2102   return at::var_out(
2103       result, self, dimnames_to_positions(self, dim), unbiased, keepdim);
2104 }
2105 
2106 std::tuple<Tensor,Tensor> var_mean(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2107   return at::var_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2108 }
2109 
2110 std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2111   return at::std_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2112 }
2113 
2114 Tensor std(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction, bool keepdim) {
2115   return at::std(self, dimnames_to_positions(self, dim), correction, keepdim);
2116 }
2117 
2118 Tensor& std_out(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction,
2119                 bool keepdim, Tensor& result) {
2120   return at::std_out(result, self, dimnames_to_positions(self, dim), correction, keepdim);
2121 }
2122 
2123 Tensor var(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction, bool keepdim) {
2124   return at::var(self, dimnames_to_positions(self, dim), correction, keepdim);
2125 }
2126 
2127 Tensor& var_out(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction,
2128                 bool keepdim, Tensor& result) {
2129   return at::var_out(
2130       result, self, dimnames_to_positions(self, dim), correction, keepdim);
2131 }
2132 
2133 std::tuple<Tensor,Tensor> var_mean(const Tensor& self, DimnameList dim,
2134                                    const std::optional<Scalar>& correction, bool keepdim) {
2135   return at::var_mean(self, dimnames_to_positions(self, dim), correction, keepdim);
2136 }
2137 
2138 std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim,
2139                                    const std::optional<Scalar>& correction, bool keepdim) {
2140   return at::std_mean(self, dimnames_to_positions(self, dim), correction, keepdim);
2141 }
2142 
2143 Tensor& norm_out(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim, ScalarType dtype, Tensor& result) {
2144   return at::norm_out(result, self, p, dimnames_to_positions(self, dim), keepdim, dtype);
2145 }
2146 
2147 Tensor& norm_out(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim, Tensor& result) {
2148   return at::norm_out(result, self, p, dimnames_to_positions(self, dim), keepdim);
2149 }
2150 
2151 Tensor norm(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim, ScalarType dtype) {
2152   return at::norm(self, p, dimnames_to_positions(self, dim), keepdim, dtype);
2153 }
2154 
2155 Tensor norm(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim) {
2156   return at::norm(self, p, dimnames_to_positions(self, dim), keepdim);
2157 }
2158 
2159 Tensor any(const Tensor& self, Dimname dim, bool keepdim) {
2160   reportNYIDimnameOverload("any");
2161 }
2162 Tensor& any_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
2163   reportNYIDimnameOverload("any");
2164 }
2165 Tensor all(const Tensor& self, Dimname dim, bool keepdim) {
2166   reportNYIDimnameOverload("all");
2167 }
2168 Tensor& all_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
2169   reportNYIDimnameOverload("all");
2170 }
2171 Tensor _is_all_true(const Tensor& self) {
2172   TORCH_INTERNAL_ASSERT(self.scalar_type() == at::kBool);
2173   return self.all();
2174 }
2175 Tensor _is_any_true(const Tensor& self) {
2176   TORCH_INTERNAL_ASSERT(self.scalar_type() == at::kBool);
2177   return self.any();
2178 }
2179 Tensor logcumsumexp(const Tensor& self, Dimname dim) {
2180   return at::logcumsumexp(self, dimname_to_position(self, dim));
2181 }
2182 Tensor& logcumsumexp_out(const Tensor& self, Dimname dim, Tensor& result) {
2183   return at::logcumsumexp_out(result, self, dimname_to_position(self, dim));
2184 }
2185 Tensor cumsum(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2186   return at::cumsum(self, dimname_to_position(self, dim), dtype);
2187 }
2188 Tensor& cumsum_(Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2189   return at::cumsum_out(self, self, dimname_to_position(self, dim), dtype);
2190 }
2191 Tensor& cumsum_out(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype, Tensor& result) {
2192   return at::cumsum_out(result, self, dimname_to_position(self, dim), dtype);
2193 }
2194 Tensor cumprod(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2195   return at::cumprod(self, dimname_to_position(self, dim), dtype);
2196 }
2197 Tensor& cumprod_(Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2198   return at::cumprod_out(self, self, dimname_to_position(self, dim), dtype);
2199 }
2200 Tensor& cumprod_out(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype, Tensor& result) {
2201   return at::cumprod_out(result, self, dimname_to_position(self, dim), dtype);
2202 }
2203 std::tuple<Tensor, Tensor> cummax(const Tensor& self, Dimname dim) {
2204   return at::cummax(self, dimname_to_position(self, dim));
2205 }
2206 std::tuple<Tensor&, Tensor&> cummax_out(const Tensor& self, Dimname dim, Tensor& values, Tensor& indices) {
2207   return at::cummax_out(values, indices, self, dimname_to_position(self, dim));
2208 }
2209 std::tuple<Tensor, Tensor> cummin(const Tensor& self, Dimname dim) {
2210   return at::cummin(self, dimname_to_position(self, dim));
2211 }
2212 std::tuple<Tensor&, Tensor&> cummin_out(const Tensor& self, Dimname dim, Tensor& values, Tensor& indices) {
2213   return at::cummin_out(values, indices, self, dimname_to_position(self, dim));
2214 }
2215 
2216 Tensor dist(const Tensor &self, const Tensor& other, const Scalar& p){
2217   return at::norm(self - other, p);
2218 }
2219 
2220 bool cpu_equal(const Tensor& self, const Tensor& other) {
2221   if (!at::namedinference::are_names_equal(
2222         self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {
2223     return false;
2224   }
2225   at::NoNamesGuard guard;
2226   TORCH_CHECK(self.device() == other.device(), "Cannot compare two tensors on "
2227               "different devices. Got: ", self.device(), " and ", other.device());
2228   if (!self.is_same_size(other)) {
2229     return false;
2230   }
2231   // Since flags like neg/conj should already be handled outside the
2232   // TensorIterator, it should be safe to take the following fast path by
2233   // ensuring that the storage and strides are exactly the same.
2234   if (self.is_alias_of(other)
2235       && self.storage_offset() == other.storage_offset()
2236       && self.dtype() == other.dtype()
2237       && self.is_contiguous() == other.is_contiguous()
2238       && self.strides().equals(other.strides())
2239       // Extra checks to ensure the safety in case cpu_equal is directly called in C++.
2240       && self.layout() == other.layout()
2241       && self.is_neg() == other.is_neg()
2242       && self.is_conj() == other.is_conj()) {
2243     if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
2244       return true;
2245     }
2246     std::atomic<bool> result{true};
2247     auto iter = TensorIteratorConfig().add_const_input(self).build();
2248     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "equal_notnan_cpu", [&] {
2249       iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
2250         if (!result) {
2251             return;
2252         }
2253         char* self_data = data[0];
2254         for (C10_UNUSED const auto i : c10::irange(dim_size)) {
2255           if (isnan_(c10::load<scalar_t>(self_data))) {
2256             result = false;
2257             return;
2258           }
2259           self_data += strides[0];
2260         }
2261       });
2262     });
2263     return result.load();
2264   }
2265 
2266   std::atomic<bool> result{true};
2267   auto iter = TensorIteratorConfig()
2268     .add_const_input(self)
2269     .add_const_input(other)
2270     .allow_cpu_scalars(true)
2271     .promote_inputs_to_common_dtype(true)
2272     .build();
2273 
2274   AT_DISPATCH_V2(iter.input_dtype(), "equal_cpu", AT_WRAP([&] {
2275     iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
2276       if (!result) {
2277           return;
2278       }
2279       char* self_data = data[0];
2280       char* other_data = data[1];
2281       for (C10_UNUSED const auto i : c10::irange(dim_size)) {
2282         if (c10::load<scalar_t>(self_data) != c10::load<scalar_t>(other_data)) {
2283           result = false;
2284           return;
2285         }
2286         self_data += strides[0];
2287         other_data += strides[1];
2288       }
2289     });
2290   }), kBool, kBFloat16, kHalf, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
2291   return result.load();
2292 }
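// Illustrative note: the alias fast path above only needs to scan one tensor.
// With identical storage, offset and strides, the two tensors can differ only
// if some element is NaN (since NaN != NaN), so a single isnan pass decides
// equality; integral dtypes cannot hold NaN and return true immediately.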
2293 
2294 // max(dim), min(dim), topk(dim), mode(dim), are examples of reduction
2295 // functions that select values. value_selecting_reduction_backward is the
2296 // backward function for those operators; it propagates the grad to the
2297 // specific value locations referred to at `indices`.
2298 Tensor value_selecting_reduction_backward_symint(const Tensor& grad, int64_t dim, const Tensor& indices, c10::SymIntArrayRef sizes, bool keepdim) {
2299   auto inplace_scatter_if_not_tensor_subclass =
2300       [&](const Tensor& grad_out, const Tensor& indices_) {
2301         auto grad_in = at::zeros_symint(sizes, grad_out.options());
2302         if (areAnyTensorSubclassLike({grad, indices})) {
2303           return grad_in.scatter(dim, indices_, grad_out);
2304         }
2305         return grad_in.scatter_(dim, indices_, grad_out);
2306       };
2307 
2308   if (!keepdim && !sizes.empty()) {
2309     auto grad_ = grad.unsqueeze(dim);
2310     auto indices_ = indices.unsqueeze(dim);
2311     return inplace_scatter_if_not_tensor_subclass(grad_, indices_);
2312   }
2313   return inplace_scatter_if_not_tensor_subclass(grad, indices);
2314 }
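// Illustrative note: the backward above builds a zero tensor of the original
// (symbolic) sizes and scatters grad into the positions recorded in `indices`.
// For example, for values/indices returned by max(dim), the input gradient is
// zero everywhere except at those indices along `dim`, which receive grad.
// The out-of-place scatter is used for tensor subclasses that may not support
// in-place mutation.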
2315 
2316 Tensor sum_csr(const Tensor &self, std::optional<ScalarType> dtype) {
2317   return self.values().sum(dtype);
2318 }
2319 
2320 Tensor sum_coo(const Tensor &self, std::optional<ScalarType> dtype) {
2321   return self._values().sum(dtype);
2322 }
2323 
2324 Tensor sum_sparse_coo(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional<ScalarType> dtype) {
2325   Tensor result;
2326   if (dim.has_value()) {
2327     if (dtype.has_value()) {
2328       result = at::_sparse_sum(self, *dim, *dtype);
2329     } else {
2330       if (c10::isIntegralType(self.scalar_type(), true)) {
2331         result = at::_sparse_sum(self, *dim, at::kLong);
2332       } else {
2333         result = at::_sparse_sum(self, *dim);
2334       }
2335     }
2336   } else {
2337     result = sum_coo(self, dtype);
2338   }
2339   if (keepdim) {
2340     auto dim_mask = make_dim_mask(dim, self.dim());
2341     for (int dim = 0; dim < self.dim(); dim++) {
2342       if (dim_mask[dim]) {
2343         result = result.unsqueeze(dim);
2344       }
2345     }
2346   }
2347   return result;
2348 }
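// Illustrative note: keepdim is emulated above by unsqueezing the reduced
// dimensions back in via the dim mask, since _sparse_sum itself always drops
// them; when dim is not given the reduction degenerates to summing _values().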
2349 
2350 Tensor sum_sparse_compressed(
2351     const Tensor& self,
2352     at::OptionalIntArrayRef dim,
2353     bool keepdim,
2354     std::optional<ScalarType> dtype) {
2355   // TODO: The signatures of sum.dim_IntList and _sparse_csr_sum.dim_dtype differ slightly
2356   // in the second parameter `dim`, which forces `dim` to be converted before
2357   // calling into `_sparse_csr_sum`. Aligning the signatures would be a better choice.
2358   TORCH_CHECK(
2359       dim.has_value(), "dim has no value, cannot be used in sum.dim_IntList");
2360   auto layout = self.layout();
2361   TORCH_CHECK(
2362       layout == kSparseCsr,
2363       "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout ",
2364       layout)
2365   return at::_sparse_csr_sum(self, *dim, keepdim, dtype);
2366 }
2367 
2368 } // namespace at::native
2369