1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/ReduceOps.h>
3
4 #include <ATen/core/Tensor.h>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/Dispatch_v2.h>
8 #include <ATen/Parallel.h>
9 #include <ATen/WrapDimUtils.h>
10 #include <ATen/WrapDimUtilsMulti.h>
11 #include <ATen/TensorIterator.h>
12 #include <ATen/TensorOperators.h>
13 #include <ATen/NamedTensorUtils.h>
14 #include <ATen/native/ReduceOpsUtils.h>
15 #include <ATen/native/Resize.h>
16 #include <ATen/native/TensorDimApply.h>
17 #include <ATen/core/grad_mode.h>
18 #include <ATen/TensorSubclassLikeUtils.h>
19
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #include <ATen/NativeFunctions.h>
23 #else
24 #include <ATen/ops/_cummax_helper.h>
25 #include <ATen/ops/_cummax_helper_native.h>
26 #include <ATen/ops/_cummin_helper.h>
27 #include <ATen/ops/_cummin_helper_native.h>
28 #include <ATen/ops/_is_all_true_native.h>
29 #include <ATen/ops/_is_any_true_native.h>
30 #include <ATen/ops/_logcumsumexp.h>
31 #include <ATen/ops/_logcumsumexp_native.h>
32 #include <ATen/ops/_sparse_csr_sum.h>
33 #include <ATen/ops/_sparse_sum.h>
34 #include <ATen/ops/_sparse_sum_native.h>
35 #include <ATen/ops/_to_copy.h>
36 #include <ATen/ops/add.h>
37 #include <ATen/ops/all_meta.h>
38 #include <ATen/ops/all_native.h>
39 #include <ATen/ops/amax.h>
40 #include <ATen/ops/amax_meta.h>
41 #include <ATen/ops/amax_native.h>
42 #include <ATen/ops/amin_meta.h>
43 #include <ATen/ops/amin_native.h>
44 #include <ATen/ops/aminmax_meta.h>
45 #include <ATen/ops/aminmax_native.h>
46 #include <ATen/ops/any_meta.h>
47 #include <ATen/ops/any_native.h>
48 #include <ATen/ops/argmax_meta.h>
49 #include <ATen/ops/argmax_native.h>
50 #include <ATen/ops/argmin_meta.h>
51 #include <ATen/ops/argmin_native.h>
52 #include <ATen/ops/cat.h>
53 #include <ATen/ops/complex.h>
54 #include <ATen/ops/cummax.h>
55 #include <ATen/ops/cummax_native.h>
56 #include <ATen/ops/cummaxmin_backward_native.h>
57 #include <ATen/ops/cummin.h>
58 #include <ATen/ops/cummin_native.h>
59 #include <ATen/ops/cumprod.h>
60 #include <ATen/ops/cumprod_backward_native.h>
61 #include <ATen/ops/cumprod_meta.h>
62 #include <ATen/ops/cumprod_native.h>
63 #include <ATen/ops/cumsum.h>
64 #include <ATen/ops/cumsum_meta.h>
65 #include <ATen/ops/cumsum_native.h>
66 #include <ATen/ops/diff_native.h>
67 #include <ATen/ops/dist_native.h>
68 #include <ATen/ops/empty.h>
69 #include <ATen/ops/empty_like.h>
70 #include <ATen/ops/equal_native.h>
71 #include <ATen/ops/exp.h>
72 #include <ATen/ops/gather.h>
73 #include <ATen/ops/gradient_native.h>
74 #include <ATen/ops/imag.h>
75 #include <ATen/ops/isnan_native.h>
76 #include <ATen/ops/linalg_vector_norm.h>
77 #include <ATen/ops/logcumsumexp.h>
78 #include <ATen/ops/logcumsumexp_native.h>
79 #include <ATen/ops/logical_xor.h>
80 #include <ATen/ops/logsumexp.h>
81 #include <ATen/ops/logsumexp_native.h>
82 #include <ATen/ops/mean.h>
83 #include <ATen/ops/mean_meta.h>
84 #include <ATen/ops/mean_native.h>
85 #include <ATen/ops/nanmean_native.h>
86 #include <ATen/ops/nansum.h>
87 #include <ATen/ops/nansum_native.h>
88 #include <ATen/ops/narrow.h>
89 #include <ATen/ops/native_norm.h>
90 #include <ATen/ops/ne.h>
91 #include <ATen/ops/norm.h>
92 #include <ATen/ops/norm_meta.h>
93 #include <ATen/ops/norm_native.h>
94 #include <ATen/ops/ones.h>
95 #include <ATen/ops/prod.h>
96 #include <ATen/ops/prod_meta.h>
97 #include <ATen/ops/prod_native.h>
98 #include <ATen/ops/real.h>
99 #include <ATen/ops/slice.h>
100 #include <ATen/ops/special_logsumexp_native.h>
101 #include <ATen/ops/sqrt.h>
102 #include <ATen/ops/squeeze.h>
103 #include <ATen/ops/stack.h>
104 #include <ATen/ops/std.h>
105 #include <ATen/ops/std_mean.h>
106 #include <ATen/ops/std_mean_native.h>
107 #include <ATen/ops/std_native.h>
108 #include <ATen/ops/sub.h>
109 #include <ATen/ops/sum.h>
110 #include <ATen/ops/sum_meta.h>
111 #include <ATen/ops/sum_native.h>
112 #include <ATen/ops/trace_native.h>
113 #include <ATen/ops/value_selecting_reduction_backward_native.h>
114 #include <ATen/ops/var.h>
115 #include <ATen/ops/var_mean.h>
116 #include <ATen/ops/var_mean_native.h>
117 #include <ATen/ops/var_native.h>
118 #include <ATen/ops/zeros.h>
119 #include <ATen/ops/zeros_like.h>
120 #endif
121
122 #include <c10/util/irange.h>
123 #include <c10/util/SmallBuffer.h>
124
125 #include <algorithm>
126 #include <cmath>
127 #include <functional>
128 #include <limits>
129 #include <numeric>
130 #include <type_traits>
131 #include <utility>
132 #include <vector>
133
134 namespace at::meta {
135
136 static ScalarType infer_dtype_from_optional(
137 const Tensor& self,
138 const std::optional<ScalarType>& opt_dtype,
139 const Tensor& result) {
140 // 'opt_dtype' has the priority for both cases.
141 if (result.defined()) {
142 // Otherwise, get the result type, if defined.
143 return opt_dtype.value_or(result.scalar_type());
144 } else {
145 // Last case is to get the self type.
146 // If the self type is an integer, we promote it to kLong.
147 return at::native::get_dtype_from_self(self, opt_dtype, true);
148 }
149 }
150
151 static IntArrayRef optional_to_arrayref(const std::optional<int64_t>& opt) {
152 return opt.has_value() ? opt.value() : IntArrayRef{};
153 }
154
155 static ScalarType get_result_or_bytebool_dtype(const Tensor& self, const Tensor& result) {
156 // Refer [all, any : uint8 compatibility]
157 if (result.defined()) {
158 return result.scalar_type();
159 } else {
160 return (self.scalar_type() == kByte) ? kByte : kBool;
161 }
162 }
163
164 static void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
165 if (result.defined()) {
166 // Refer [all, any : uint8 compatibility]
167 TORCH_CHECK(
168 result.scalar_type() == ScalarType::Bool ||
169 result.scalar_type() == ScalarType::Byte,
170 name, " only supports bool tensor for result, got: ",
171 result.scalar_type());
172 }
173 }
174
175 // Note [all, any : uint8 compatibility]:
176 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177 // For NumPy compatibility, `all` and `any` return a Tensor of dtype `bool`.
178 // However, for backward compatibility, when the input dtype is `uint8`,
179 // they return a Tensor of the same dtype `uint8`.
180 // Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
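// Illustrative sketch of the rule above, assuming standard ATen dispatch
// (shown only as a comment, not part of this file's logic):
//   at::Tensor b = at::ones({3}, at::kByte);
//   at::all(b).scalar_type();   // kByte, kept for NumPy/uint8 backward compatibility
//   at::Tensor f = at::ones({3});
//   at::any(f).scalar_type();   // kBool, the default for all/any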
181 static void allany_meta(
182 impl::MetaBase& meta,
183 const char* name,
184 const Tensor& self,
185 OptionalIntArrayRef dims,
186 bool keepdim) {
187 const auto& result = meta.maybe_get_output();
188 check_result_is_bytebool(name, self, result);
189 auto out_dtype = get_result_or_bytebool_dtype(self, result);
190 resize_reduction(meta, self, dims, keepdim, out_dtype, /*allow_empty_dims=*/true);
191 }
192
193 TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
194 allany_meta(*this, "all", self, dim, keepdim);
195 }
196
197 TORCH_META_FUNC2(all, dims)(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
198 allany_meta(*this, "all", self, dim, keepdim);
199 }
200
201 TORCH_META_FUNC(all)(const Tensor& self) {
202 allany_meta(*this, "all", self, {}, false);
203 }
204
205 TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
206 allany_meta(*this, "any", self, dim, keepdim);
207 }
208
209 TORCH_META_FUNC2(any, dims)(const Tensor& self, OptionalIntArrayRef dim, bool keepdim) {
210 allany_meta(*this, "any", self, dim, keepdim);
211 }
212
213 TORCH_META_FUNC(any)(const Tensor& self) {
214 allany_meta(*this, "any", self, {}, false);
215 }
216
217 static void check_argmax_argmin(
218 const char* name,
219 const Tensor& self,
220 const std::optional<int64_t>& dim) {
221 if (dim.has_value()) {
222 auto dim_ = maybe_wrap_dim(dim.value(), self.dim());
223 native::zero_numel_check_dims(self, dim_, name);
224 } else {
225 TORCH_CHECK_INDEX(
226 self.numel() != 0,
227 name, ": Expected reduction dim to be specified for input.numel() == 0.");
228 }
229 }
230
231 TORCH_META_FUNC(argmax)
232 (const Tensor& self, std::optional<int64_t> dim, bool keepdim) {
233 check_argmax_argmin("argmax()", self, dim);
234 resize_reduction(*this, self, optional_to_arrayref(dim), keepdim, kLong);
235 }
236
237 TORCH_META_FUNC(argmin)
238 (const Tensor& self, std::optional<int64_t> dim, bool keepdim) {
239 check_argmax_argmin("argmin()", self, dim);
240 resize_reduction(*this, self, optional_to_arrayref(dim), keepdim, kLong);
241 }
242
243 static void meta_func_cum_ops(
244 impl::MetaBase& meta,
245 const char* name,
246 const Tensor& self,
247 int64_t dim,
248 std::optional<ScalarType> dtype) {
249 // Checking whether 'dim' is valid.
250 maybe_wrap_dim(dim, self.dim());
251
252 const auto& result = meta.maybe_get_output();
253 ScalarType out_dtype{};
254
255 if (result.defined()) {
256 out_dtype = dtype.value_or(result.scalar_type());
257 } else {
258 auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
259 out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
260 }
261
262 meta.set_output_raw_strided(0, self.sizes(), {}, self.options().dtype(out_dtype));
263 namedinference::propagate_names(result, self);
264 }
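// A minimal sketch of the dtype inference above, assuming standard ATen
// behavior (illustrative comment only):
//   at::cumsum(at::ones({4}, at::kInt), 0).scalar_type();    // kLong (integral promotes)
//   at::cumsum(at::ones({4}, at::kFloat), 0).scalar_type();  // kFloat (kept as-is)
//   at::cumsum(at::ones({4}, at::kInt), 0, at::kDouble);     // kDouble (explicit dtype wins)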
265
266 TORCH_META_FUNC(cumsum)
267 (const Tensor& self, int64_t dim, std::optional<ScalarType> dtype) {
268 meta_func_cum_ops(*this, "cumsum", self, dim, dtype);
269 }
270
271 TORCH_META_FUNC(cumprod)
272 (const Tensor& self, int64_t dim, std::optional<ScalarType> dtype) {
273 meta_func_cum_ops(*this, "cumprod", self, dim, dtype);
274 }
275
276 TORCH_META_FUNC2(sum, dim_IntList)
277 (const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
278 auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
279 resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
280 }
281
282 TORCH_META_FUNC2(prod, dim_int)
283 (const Tensor& self,
284 int64_t dim,
285 bool keepdim,
286 std::optional<ScalarType> dtype) {
287 auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output());
288 resize_reduction(*this, self, dim, keepdim, out_dtype);
289 }
290
291 TORCH_META_FUNC2(mean, dim)
292 (const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
293 auto in_dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
294
295 if (!at::isFloatingType(in_dtype) && !at::isComplexType(in_dtype)) {
296 std::string what = "Input";
297 std::string dtype = toString(self.scalar_type());
298
299 if (opt_dtype.has_value()) {
300 what = "Optional";
301 dtype = toString(opt_dtype.value());
302 }
303
304 TORCH_CHECK(
305 false,
306 "mean(): could not infer output dtype. ",
307 what, " dtype must be either a floating point or complex dtype. ",
308 "Got: ", dtype);
309 }
310
311 auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
312 resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
313 }
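// Illustrative consequence of the check above (assuming standard ATen calls):
//   at::mean(at::ones({4}, at::kInt));              // throws: cannot infer a float/complex dtype
//   at::mean(at::ones({4}, at::kInt), at::kFloat);  // ok: explicit floating-point dtype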
314
315 static ScalarType get_result_or_self_value_dtype(
316 const Tensor& self,
317 const Tensor& result,
318 const std::optional<ScalarType>& dtype) {
319 if (result.defined()) {
320 return result.scalar_type();
321 } else {
322 return dtype.value_or(toRealValueType(self.scalar_type()));
323 }
324 }
325
326 TORCH_META_FUNC2(norm, ScalarOpt_dim)
327 (const Tensor& self, const OptionalScalarRef p, IntArrayRef dim, bool keepdim) {
328 TORCH_CHECK(
329 at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
330 "norm(): input dtype should be either floating point or complex. "
331 "Got ", self.scalar_type(), " instead.");
332
333 auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), std::nullopt);
334 resize_reduction(*this, self, dim, keepdim, out_dtype);
335 }
336
337 TORCH_META_FUNC2(norm, ScalarOpt_dim_dtype)
338 (const Tensor& self,
339 const OptionalScalarRef p,
340 IntArrayRef dim,
341 bool keepdim,
342 ScalarType dtype) {
343 TORCH_CHECK(
344 at::isFloatingType(dtype) || at::isComplexType(dtype),
345 "norm(): the desired output dtype should be either floating point or complex. "
346 "Got ", dtype, " instead.");
347
348 auto out_dtype = get_result_or_self_value_dtype(self, maybe_get_output(), dtype);
349 resize_reduction(*this, self, dim, keepdim, out_dtype);
350 }
351
352 TORCH_META_FUNC(aminmax)
353 (const Tensor& self, std::optional<int64_t> dim_opt, bool keepdim) {
354 DimVector shape;
355 if (dim_opt.has_value()) {
356 auto dim = maybe_wrap_dim(dim_opt.value(), self.ndimension());
357 native::zero_numel_check_dims(self, dim, "aminmax");
358 shape = get_reduction_shape(self, dim, keepdim);
359 } else {
360 TORCH_CHECK(
361 self.numel() > 0,
362 "aminmax(): cannot compute aminmax over an empty dimension as the "
363 "operation has no identity.");
364 if (keepdim) {
365 shape = DimVector(self.ndimension(), 1);
366 }
367 }
368 const auto options = self.options();
369 this->set_output_raw_strided(0, shape, {}, options);
370 this->set_output_raw_strided(1, shape, {}, options);
371 }
372
373 TORCH_META_FUNC(amax)
374 (const Tensor& self, IntArrayRef dim, bool keepdim) {
375 auto maybe_result = maybe_get_output();
376 if (maybe_result.defined()) {
377 TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
378 self.scalar_type(), " for input's dtype and ", maybe_result.scalar_type(), " for out's dtype.");
379 }
380 if (self.numel() == 0) {
381 at::native::zero_numel_check_dims(self, dim, "amax()");
382 }
383 const ScalarType& out_dtype = maybe_result.defined() ? maybe_result.scalar_type() : self.scalar_type();
384 resize_reduction(*this, self, dim, keepdim, out_dtype);
385 }
386
387 TORCH_META_FUNC(amin)
388 (const Tensor& self, IntArrayRef dim, bool keepdim) {
389 auto maybe_result = maybe_get_output();
390 if (maybe_result.defined()) {
391 TORCH_CHECK(self.scalar_type() == maybe_result.scalar_type(), "Expected the dtype for input and out to match, but got ",
392 self.scalar_type(), " for input's dtype and ", maybe_result.scalar_type(), " for out's dtype.");
393 }
394 if (self.numel() == 0) {
395 at::native::zero_numel_check_dims(self, dim, "amin()");
396 }
397 const ScalarType& out_dtype = maybe_result.defined() ? maybe_result.scalar_type() : self.scalar_type();
398 resize_reduction(*this, self, dim, keepdim, out_dtype);
399 }
400
401 } // namespace at::meta
402
403 namespace at::native {
404
405 DEFINE_DISPATCH(aminmax_stub);
406 DEFINE_DISPATCH(aminmax_allreduce_stub);
407
408 TORCH_IMPL_FUNC(aminmax_out)
409 (const Tensor& self,
410 std::optional<int64_t> dim_opt,
411 bool keepdim,
412 const Tensor& min,
413 const Tensor& max) {
414 auto mutable_min = const_cast<Tensor&>(min);
415 auto mutable_max = const_cast<Tensor&>(max);
416 if (dim_opt.has_value()) {
417 aminmax_stub(
418 self.device().type(),
419 self,
420 maybe_wrap_dim(dim_opt.value(), self.ndimension()),
421 keepdim,
422 mutable_min,
423 mutable_max);
424 } else {
425 aminmax_allreduce_stub(self.device().type(), self.contiguous(), mutable_min, mutable_max);
426 }
427 }
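// Usage sketch (illustrative, assuming the public at::aminmax API):
//   auto [mn, mx]   = at::aminmax(t);                                // all-reduce over a contiguous copy
//   auto [mn0, mx0] = at::aminmax(t, /*dim=*/0, /*keepdim=*/true);   // per-dim reduction via aminmax_stub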
428
429 DEFINE_DISPATCH(sum_stub);
430 DEFINE_DISPATCH(nansum_stub);
431 DEFINE_DISPATCH(std_var_stub);
432 DEFINE_DISPATCH(prod_stub);
433 DEFINE_DISPATCH(norm_stub);
434 DEFINE_DISPATCH(mean_stub);
435 DEFINE_DISPATCH(and_stub);
436 DEFINE_DISPATCH(or_stub);
437 DEFINE_DISPATCH(min_values_stub);
438 DEFINE_DISPATCH(max_values_stub);
439 DEFINE_DISPATCH(argmax_stub);
440 DEFINE_DISPATCH(argmin_stub);
441 DEFINE_DISPATCH(cumsum_stub);
442 DEFINE_DISPATCH(cumprod_stub);
443 DEFINE_DISPATCH(logcumsumexp_stub);
444
445 Tensor _logcumsumexp_cpu(const Tensor& self, int64_t dim) {
446 Tensor result = at::empty_like(self, MemoryFormat::Contiguous);
447 return _logcumsumexp_out_cpu(self, dim, result);
448 }
449
450 Tensor& _logcumsumexp_out_cpu(const Tensor& self, int64_t dim, Tensor& result) {
451 logcumsumexp_stub(self.device().type(), result, self, dim);
452 return result;
453 }
454
455 Tensor logcumsumexp(const Tensor& self, int64_t dim) {
456 auto result = [&]() {
457 NoNamesGuard guard;
458 return at::_logcumsumexp(self, dim);
459 }();
460 namedinference::propagate_names(result, self);
461 return result;
462 }
463
464 Tensor& logcumsumexp_out(const Tensor& self, int64_t dim, Tensor& result) {
465 check_scalar_type_device_layout_equal(result, self);
466 {
467 NoNamesGuard guard;
468 at::_logcumsumexp_out(result, self.toType(result.scalar_type()), dim);
469 }
470 namedinference::propagate_names(result, self);
471 return result;
472 }
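// What logcumsumexp computes along `dim` is
//   out[i] = log(sum_{j <= i} exp(self[j]))
// which kernels typically evaluate with the numerically stable running
// log-add-exp recurrence (a sketch; the actual loop lives in logcumsumexp_stub):
//   out[0] = self[0];
//   out[i] = max(out[i-1], self[i]) + log1p(exp(-|out[i-1] - self[i]|));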
473
474 template <class Stub>
475 void impl_func_cum_ops(
476 const Tensor& self,
477 int64_t dim,
478 const Tensor& result,
479 Stub& stub) {
480 NoNamesGuard guard;
481 if (self.dim() == 0) {
482 result.fill_(self);
483 } else if (self.numel() == 0) {
484 result.zero_();
485 } else {
486 dim = maybe_wrap_dim(dim, self.dim());
487 stub(self.device().type(), result, self.to(result.scalar_type()), dim);
488 }
489 }
490
491 TORCH_IMPL_FUNC(cumsum_out)
492 (const Tensor& self,
493 int64_t dim,
494 std::optional<ScalarType> dtype,
495 const Tensor& result) {
496 impl_func_cum_ops(self, dim, result, cumsum_stub);
497 }
498
499 TORCH_IMPL_FUNC(cumprod_out)
500 (const Tensor& self,
501 int64_t dim,
502 std::optional<ScalarType> dtype,
503 const Tensor& result) {
504 impl_func_cum_ops(self, dim, result, cumprod_stub);
505 }
506
507 static Tensor reversed_cumsum(const Tensor& w, int64_t dim) {
508 return w.flip(dim).cumsum(dim).flip(dim);
509 }
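// Illustrative behavior (for a 1-D tensor): reversed_cumsum([a, b, c], 0) is
// the suffix sum [a+b+c, b+c, c], which is exactly the sum_{j >= k} term used
// in the cumprod backward below.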
510
511 Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, const Tensor& output) {
512 /*
513 We show here how to derive an O(n) gradient formula for
514 arbitrary inputs. It follows via a basic application of the
515 chain rule together with a number of observations for different
516 cases. We assume that x is an n-dimensional vector and y = cumprod(x).
517 In the actual implementation we will need to play a bit with masks
518 to be able to implement the formulas deduced here for tensors.
519
520 We will first deduce the formula for the case when
521 x[i] != 0 for 1 <= i <= n.
522
523 For F : R^n -> R the cost function (we will look at the complex case later),
524 we have
525
526 dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k) (1)
527
528 The term dF / dy_j is just grad_output[j] (assuming again
529 everything is one-dimensional).
530
531 The term (dy_j / dx_k) is easily seen to be
532
533 if j >= k
534 dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i
535 else:
536 dy_j / dx_k = 0
537
538 Note that the indicator (j>=k) can be taken out
539 by replacing the sum in (1) with a sum from
540 k <= j <= n.
541
542 Thus,
543 dF / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k)
544
545 with
546 dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i (2)
547
548 Note that this last term is just the cumulative product
549 with k omitted. Thus, if x_k (the input) is nonzero, we can
550 just express this as
551
552 dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k
553 = y_j / x_k
554
555 So therefore,
556
557 dF / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k
558
559 This formula just makes sense when input[i] != 0 for every i.
560
561 Assume now that there exists at least a zero in the input.
562 Denote by z1 the first element 1 <= z1 <= n with input[z1] = 0
563 and z2 the second element z1 < z2 <= n with input[z2] = 0,
564 (or z2 = n if there is just one zero in input)
565
566 We have three cases.
567
568 k > z1:
569 Looking at (2), we see that dy_j / dx_k = 0, for j >= k, as these terms
570 all include a x_{z1} which is zero. As such, dF / dx_k = 0 in this case
571
572 k < z1:
573 Reasoning as in the previous case, we see that for these elements we have that
574
575 dF / dx_k = sum_{k <= j < z1} grad_output[j] * (dy_j / dx_k)
576
577 as the terms of the sum for j in z1 <= j <= n are all zero
578
579 k = z1:
580 Similar to the case k < z1, we have that
581
582 dF / dx_z1 = sum_{z1 <= j < z2} grad_output[j] * (dy_j / dx_z1)
583
584 This case has a subtlety though. To compute (dy_j / dx_z1), we cannot use the formula
585
586 dy_j / dx_z1 = y_j / x_z1
587
588 as, y_j = x_z1 = 0 for j >= z1. We need to compute it with the formula for its derivative,
589 that is:
590
591 dy_j / dx_z1 = prod(x[:z1]) * (grad_output[z1] + sum(grad_output[z1+1:z2] * cumprod(x[z1+1:z2])))
592
593     When the inputs are complex, this map is holomorphic. As such, computing
594     its backward is just the conjugate of the usual backward. This simplifies to
595     conjugating the input. We may also reuse the output since, as the map is holomorphic,
596 cumprod(input.conj()) = cumprod(input).conj()
597 */
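  // Worked example of the nonzero case (illustrative): for x = [2, 3, 4],
  // y = cumprod(x) = [2, 6, 24] and
  //   dF/dx_k = sum_{j >= k} grad[j] * y[j] / x[k],
  // i.e. grad_input = reversed_cumsum(grad * y, dim) / x, which is the fast
  // path taken below when no zeros are present.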
598
599 if (input.sym_numel() <= 1) {
600 return grad;
601 }
602 dim = at::maybe_wrap_dim(dim, input.dim());
603 const int64_t dim_size = input.sym_sizes()[dim].guard_int(__FILE__, __LINE__);
604 if (dim_size == 1) {
605 return grad;
606 }
607
608 // To enable complex support.
609   // From this line on, `input_conj` and `output_conj`
610 // are interchangeable with `input` and `output`.
611 auto input_conj = input.conj();
612 auto output_conj = output.conj();
613
614 // For Composite Compliance, we always choose the slower but composite compliant path.
615   bool are_inputs_tensors_subclass = areAnyTensorSubclassLike({input, grad, output});
616
617 const auto w = output_conj * grad;
618 const auto is_zero = input == 0;
619   if (!are_inputs_tensors_subclass) {
620 if (is_zero.any().item<uint8_t>() == 0) {
621 return reversed_cumsum(w, dim).div(input_conj);
622 }
623 }
624
625 // If we are not computing a second order gradient, we can use an
626 // O(n) implementation. The derivative of this implementation is _not_
627 // the second derivative of cumprod. As such, we fallback to a less efficient
628 // O(n^2) implementation when at::GradMode::is_enabled().
629   if (!at::GradMode::is_enabled() && !are_inputs_tensors_subclass) {
630 // n.b. This could probably be implemented much faster with a kernel
631
632 // From here on we need to use some mask gymnastics to
633 // account for the tensorial dimensions
634 // We do a cumsum of the zeros along the dimension.
635 // For a vector is_zero = [False, True, False, True, False]
636 // we would have cumsum = [0, 1, 1, 2, 2]
637 // As such we have (in python code for simplicity)
638 // The mask for the range [0, z1):
639 // cumsum == 0
640 // The indices of the first zero z1 and zeros when
641 // there is no first zero:
642 // indices = (cumsum == 1).max(dim, keepdim=True).indices
643 // The mask for the first zero:
644 // zeros_like(indices).scatter_(dim, indices, 1.) & cumsum == 1
645     // Note that the logical_and with cumsum == 1 accounts
646 // for the case when there is no first zero
647 Tensor grad_input = at::zeros_symint(input.sym_sizes(), grad.options());
648 const auto cumsum = is_zero.cumsum(dim);
649
650 // case k < z1
651 // select everything before the first zero [0, z1)
652 auto mask = cumsum == 0;
653 // equiv to grad_input[mask] = deriv[grad]
654 grad_input.masked_scatter_(mask,
655 reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input_conj).masked_select(mask));
656 // select everything from the first zero to the second zero [z1, z2)
657 mask = cumsum == 1;
658
659 // case k = z1
660     // We start by selecting the first zero [z1]
661     // We locate the indices of the first zero using the max function
662     // We then go from the indices to a mask using scatter_
663 // When there is no zero in the slice, max will return the index 0.
664 // To account for this, we need to do an intersection with mask,
665 // which is true in the range [z1, z2)
666 const auto first_zero_index = std::get<1>(mask.max(dim, /*keepdim*/ true));
667 const auto first_zero_mask = at::zeros_like(mask)
668 .scatter_(dim, first_zero_index, /*src*/ 1)
669 .logical_and_(mask);
670
671 // select everything between the first zero and the second zero (z1, z2)
672 mask &= ~first_zero_mask;
673 // here we compute
674     // dy_j / dx_z1 = sum(cumprod(input[z1+1:z2]) * grad[z1+1:z2]) * output[z1-1]
675 // relu_() necessary as gather does not support negative indices
676 // finally, we do grad_input[z1] = dy_j / dx_z1
677 grad_input.masked_scatter_(first_zero_mask,
678 input_conj.masked_fill(~mask, 1.).cumprod(dim)
679 .mul_(grad.masked_fill(cumsum != 1, 0.))
680 .sum(dim, /*keepdim*/true)
681 .mul_(at::gather(output_conj, dim, (first_zero_index - 1).relu_())
682 .masked_fill_(first_zero_index == 0, 1.))
683 .masked_select(first_zero_mask));
684 return grad_input;
685   } else { // at::GradMode::is_enabled() or tensor subclass inputs
686 /*
687 If the input is nonzero, we need to calculate the dy_j / dx_k
688 by using the formula (2), called in the code omitted_products.
689
690 The way the code calculates it is simply by noting that
691
692 prod_{1 <= i <= j, i != k} x_i
693 = (prod_{1 <= i <= k} x_i) * (prod_{k + 1 <= i <= j} x_i)
694
695     The first term is calculated as prods_until_k, which, since it
696     doesn't depend on j, is easy to vectorize.
697
698     The second term (indexed by j) is the cumulative product of
699     x_{k+1}, x_{k+2}, ..., x_n, named prods_from_k_plus_1 in the code,
700     and it is calculated as a cumprod.
701
702 In order to vectorize this properly, we need to add to
703 omitted_products the dimensions where k > j, and therefore
704 dy_j / dx_k = 0, which is done right after the assert.
705 */
706
707 Tensor grad_input;
708 // For Composite Compliance, we will use
709 // at::stack on the grad slices, hence the vector.
710 std::vector<Tensor> grad_inputs;
711     if (are_inputs_tensors_subclass) {
712 grad_inputs.reserve(dim_size);
713 } else {
714 grad_input = at::zeros(input.sizes(), grad.options());
715 }
716 auto ones_size = input.sym_sizes().vec();
717 ones_size[dim] = 1;
718 const Tensor ones = at::ones({1}, grad.options()).expand_symint(ones_size);
719 Tensor prods_from_k_plus_1;
720 Tensor omitted_products;
721 for (const auto k : c10::irange(dim_size)) {
722 if (k == 0) {
723 prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k + 1), dim);
724 omitted_products = at::cat({ones, std::move(prods_from_k_plus_1)}, dim);
725 } else if (k == dim_size - 1) {
726 const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
727 omitted_products = prods_until_k;
728 } else {
729 const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
730 prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k+1), dim);
731 omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
732 omitted_products = at::cat({prods_until_k, omitted_products}, dim);
733 }
734
735 // At this point omitted_products is the same size
736 // as input, except on the dimension dim where it's
737 // dim_size - k
738 TORCH_CHECK(omitted_products.sym_size(dim) == dim_size - k);
739
740 auto grad_slice = at::sum(grad.slice(dim, k) * omitted_products, dim);
741       if (are_inputs_tensors_subclass) {
742 grad_inputs.push_back(grad_slice);
743 } else {
744 grad_input.select(dim, k).copy_(grad_slice);
745 }
746 }
747
748     return are_inputs_tensors_subclass ? at::stack(grad_inputs, dim) : std::move(grad_input);
749 }
750 }
751
752 // Implement std::isnan for integral types (MSVC).
753 namespace {
754 #ifdef _MSC_VER
755 template<typename T>
756 inline typename std::enable_if<std::is_integral<T>::value, bool>::type isnan_(T x) {
757 return false;
758 }
759 template<typename T>
760 inline typename std::enable_if<!std::is_integral<T>::value, bool>::type isnan_(T x) {
761 return std::isnan(x);
762 }
763 #else
764 template<typename T>
765 inline bool isnan_(T x) {
766 return std::isnan(x);
767 }
768 #endif
769 }
770
771 template<typename T1, typename T2, typename Operation>
772 void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data,
773 int self_dim_size, int self_stride, int values_stride, int indices_stride) {
774 Operation op;
775 T1 out = c10::load(self_data);
776 int idx = 0;
777 for (const auto i : c10::irange(self_dim_size)) {
778 T1 curr_elem = c10::load(&self_data[i*self_stride]);
779 if(isnan_(curr_elem) || (!isnan_(out) && op(curr_elem, out))) {
780 out = curr_elem;
781 idx = i;
782 }
783 values_data[i*values_stride] = out;
784 indices_data[i*indices_stride] = idx;
785 }
786 }
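// Illustrative trace of the loop above with std::greater_equal (cummax) on
// self = [1, NaN, 0, 2]: once a NaN is seen it wins every comparison, so
// values = [1, NaN, NaN, NaN] and indices = [0, 1, 1, 1], matching the
// NaN-propagation convention of max/cummax.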
787
788 void cummax_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
789 AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
790 self.scalar_type(), "cummax_cpu",
791 [&] {
792 at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::greater_equal<scalar_t>>);
793 });
794 }
795
796 std::tuple<Tensor&, Tensor&> cummax_out(const Tensor& self, int64_t dim, Tensor& values, Tensor& indices) {
797 check_scalar_type_device_layout_equal(values, self);
798 check_scalar_type_device_layout_equal(indices, at::empty({0}, self.options().dtype(at::kLong)));
799 {
800 NoNamesGuard guard;
801 at::native::resize_output(values, self.sizes());
802 at::native::resize_output(indices, self.sizes());
803 if(self.dim() == 0) {
804 values.fill_(self);
805 indices.fill_(0);
806 } else if(self.numel() != 0) {
807 dim = maybe_wrap_dim(dim, self.dim());
808 at::_cummax_helper(self, values, indices, dim);
809 }
810 }
811 namedinference::propagate_names(values, self);
812 namedinference::propagate_names(indices, self);
813 return std::forward_as_tuple(values, indices);
814 }
815
816 std::tuple<Tensor, Tensor> cummax(const Tensor& self, int64_t dim) {
817 auto values = at::empty(self.sizes(), self.options());
818 auto indices = at::empty(self.sizes(), self.options().dtype(at::kLong));
819 at::cummax_out(values, indices, self, dim);
820 return std::make_tuple(values, indices);
821 }
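// Usage sketch (illustrative): for self = [1, 3, 2],
//   auto [values, indices] = at::cummax(self, 0);
// yields values = [1, 3, 3] and indices = [0, 1, 1]; on ties the later index
// wins because the helper uses std::greater_equal.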
822
823 void cummin_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
824 AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
825 self.scalar_type(), "cummin_cpu",
826 [&] {
827 at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::less_equal<scalar_t>>);
828 });
829 }
830
831 std::tuple<Tensor&, Tensor&> cummin_out(const Tensor& self, int64_t dim, Tensor& values, Tensor& indices) {
832 check_scalar_type_device_layout_equal(values, self);
833 check_scalar_type_device_layout_equal(indices, at::empty({0}, self.options().dtype(at::kLong)));
834 {
835 NoNamesGuard guard;
836 at::native::resize_output(values, self.sizes());
837 at::native::resize_output(indices, self.sizes());
838 if(self.dim() == 0) {
839 values.fill_(self);
840 indices.fill_(0);
841 } else if(self.numel() != 0) {
842 dim = maybe_wrap_dim(dim, self.dim());
843 at::_cummin_helper(self, values, indices, dim);
844 }
845 }
846 namedinference::propagate_names(values, self);
847 namedinference::propagate_names(indices, self);
848 return std::forward_as_tuple(values, indices);
849 }
850
851 std::tuple<Tensor, Tensor> cummin(const Tensor& self, int64_t dim) {
852 auto values = at::empty(self.sizes(), self.options());
853 auto indices = at::empty(self.sizes(), self.options().dtype(at::kLong));
854 at::cummin_out(values, indices, self, dim);
855 return std::make_tuple(values, indices);
856 }
857
858 Tensor cummaxmin_backward(const Tensor& grad, const Tensor& input, const Tensor& indices, int64_t dim) {
859 if (input.sym_numel() == 0) {
860 return input;
861 }
862 auto result = at::zeros_symint(input.sym_sizes(), input.options());
863
864 // for composite compliance, use out-of-place variant of
865 // `scatter_add` if `indices` or `grad` is a Tensor Subclass.
866 if (areAnyTensorSubclassLike({indices, grad})) {
867 return result.scatter_add(dim, indices, grad);
868 }
869 return result.scatter_add_(dim, indices, grad);
870 }
871
872 static Tensor prepend_append_on_dim(const Tensor& self, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append, int64_t dim) {
873 // Helper for diff that handles prepending and appending when at least one is present
874   TORCH_INTERNAL_ASSERT(prepend.has_value() || append.has_value(), "either prepend or append must have a value");
875 if (!prepend.has_value() && append.has_value()) {
876 return at::cat({self, append.value()}, dim);
877 } else if (prepend.has_value() && !append.has_value()) {
878 return at::cat({prepend.value(), self}, dim);
879 } else {
880 return at::cat({prepend.value(), self, append.value()}, dim);
881 }
882 }
883
884 static inline void diff_check_compatible_shape(const Tensor& self, const std::optional<Tensor>& other, int64_t dim) {
885 // Helper for diff that checks whether the shape of the tensor to prepend or append
886 // is compatible with that of input
887 if (other.has_value()) {
888 int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim(), false);
889
890 TORCH_CHECK(
891 other.value().dim() == self.dim(),
892 "diff expects prepend or append to be the same dimension as input");
893
894 for (const auto i : c10::irange(other.value().dim())) {
895 if (i == wrapped_dim) {
896 continue;
897 }
898 TORCH_SYM_CHECK(
899 other.value().sym_size(i).sym_eq(self.sym_size(i)),
900 "diff expects the shape of tensor to prepend or append to match that of"
901 " input except along the differencing dimension;"
902 " input.size(", i, ") = ", self.sym_size(i), ", but got"
903 " tensor.size(", i, ") = ", other.value().sym_size(i));
904 }
905 }
906 }
907
908 static inline void diff_check(const Tensor& self, int64_t n, int64_t dim, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append) {
909 // Helper for diff that checks whether its parameters are valid
910 TORCH_CHECK(
911 self.dim() >= 1,
912 "diff expects input to be at least one-dimensional");
913
914 TORCH_CHECK(
915 n >= 0,
916 "order must be non-negative but got ", n);
917
918 diff_check_compatible_shape(self, prepend, dim);
919 diff_check_compatible_shape(self, append, dim);
920 }
921
922 static inline Tensor diff_helper(const Tensor& self, int64_t n, int64_t dim) {
923 if (n == 0) {
924 auto result = at::zeros_like(self);
925 result.copy_(self);
926 return result;
927 }
928
929 auto out_len = self.sym_size(dim) - 1;
930 auto result = self;
931 bool is_kBool = (self.dtype() == at::kBool);
932 n = n > self.sym_size(dim) ? self.sym_size(dim).guard_int(__FILE__, __LINE__) : n;
933
934 for (C10_UNUSED const auto i : c10::irange(n)) {
935 if (is_kBool) {
936 result = at::logical_xor(
937 at::narrow_symint(result, dim, 1, out_len),
938 at::narrow_symint(result, dim, 0, out_len)
939 );
940 } else {
941 result = at::narrow_symint(result, dim, 1, out_len) - at::narrow_symint(result, dim, 0, out_len);
942 }
943 out_len = out_len - 1;
944 }
945
946 return result;
947 }
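// Illustrative behavior (for a 1-D input): diff([1, 4, 9, 16], n=1) is
// [3, 5, 7]; applying the loop above twice (n=2) gives [2, 2]. For kBool
// inputs the pairwise subtraction is replaced by logical_xor, as in the loop.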
948
949 Tensor diff(const Tensor& self, int64_t n, int64_t dim, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append) {
950 diff_check(self, n, dim, prepend, append);
951 if ((!prepend.has_value() && !append.has_value()) || n == 0) {
952 return diff_helper(self, n, dim);
953 } else {
954 auto a = prepend_append_on_dim(self, prepend, append, dim);
955 return diff_helper(a, n, dim);
956 }
957 }
958
959 static inline Tensor& diff_out_helper(const Tensor& self, int64_t n, int64_t dim, Tensor& result) {
960 if (n == 0) {
961 if (resize_output_check_symint(result, self.sym_sizes())) {
962 result.resize__symint(self.sym_sizes());
963 }
964 check_scalar_type_device_layout_equal(result, self);
965 return result.copy_(self);
966 }
967
968 n = n > self.sym_size(dim) ? self.sym_size(dim).guard_int(__FILE__, __LINE__) : n;
969 const auto out_len = self.sym_size(dim) - n;
970 auto prev_result = self;
971
972 if (n > 1) {
973 prev_result = diff_helper(self, n - 1, dim);
974 }
975
976 if (self.dtype() == at::kBool) {
977 at::logical_xor_out(
978 result,
979 at::narrow_symint(prev_result, dim, 1, out_len),
980 at::narrow_symint(prev_result, dim, 0, out_len)
981 );
982 } else {
983 at::sub_out(
984 result,
985 at::narrow_symint(prev_result, dim, 1, out_len),
986 at::narrow_symint(prev_result, dim, 0, out_len)
987 );
988 }
989
990 return result;
991 }
992
993 Tensor& diff_out(const Tensor& self, int64_t n, int64_t dim, const std::optional<Tensor>& prepend, const std::optional<Tensor>& append, Tensor& result) {
994 diff_check(self, n, dim, prepend, append);
995 if ((!prepend.has_value() && !append.has_value()) || n == 0) {
996 return diff_out_helper(self, n, dim, result);
997 } else {
998 auto a = prepend_append_on_dim(self, prepend, append, dim);
999 return diff_out_helper(a, n, dim, result);
1000 }
1001 }
1002
1003 static void pre_check_gradient(const Tensor& self, std::optional<int64_t> spacing_size, at::OptionalIntArrayRef dim, int64_t edge_order) {
1004 // Helper for gradient function to make sure input data satisfies prerequisites
1005 TORCH_CHECK(self.scalar_type() != ScalarType::Byte, "torch.gradient does not support uint8 input.");
1006 if (spacing_size.has_value() && !dim.has_value()) {
1007 // NOTE: If spacing was given as a scalar, the callers of this function
1008 // create a spacing vector of the expected size, and this check passes
1009 TORCH_CHECK(spacing_size.value() == self.dim(),
1010 "torch.gradient expected spacing to be unspecified, a scalar, or a list ",
1011 "of length equal to 'self.dim() = ", self.dim(), "', since dim argument ",
1012 "was not given, but got a list of length ", spacing_size.value());
1013 }
1014 if (spacing_size.has_value() && dim.has_value()) {
1015 TORCH_CHECK(spacing_size.value() == static_cast<int64_t>(dim.value().size()),
1016     "torch.gradient expected spacing to be unspecified, a scalar, or its spacing and dim arguments to have the same length, but got a spacing argument of length ", spacing_size.value(), " and a dim argument of length ", dim.value().size(), "." );
1017 }
1018 TORCH_CHECK(edge_order == 1 || edge_order == 2, "torch.gradient only supports edge_order=1 and edge_order=2.");
1019 if (dim.has_value()) {
1020 // The following function get called to check whether dim argument satisfies prerequisites.
1021 // The output of the function is not used for the computation of gradient.
1022 dim_list_to_bitset(dim.value(), self.dim());
1023 for (const auto i : c10::irange(dim.value().size())) {
1024 TORCH_CHECK(self.size(dim.value()[i]) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
1025 }
1026 } else {
1027 for (const auto i : c10::irange(self.dim())) {
1028 TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
1029 }
1030 }
1031 }
1032
1033 static std::vector<Tensor> gradient_helper(const Tensor& self, TensorList coordinates, IntArrayRef dim, int64_t edge_order) {
1034 for (const auto i : c10::irange(coordinates.size())) {
1035 TORCH_CHECK(self.device() == coordinates[i].device(), "torch.gradient expected each tensor to be on the same device, but got devices ", self.device(), " and ", coordinates[i].device(), "!");
1036 }
1037
1038 std::vector<Tensor> result;
1039 for (const auto i : c10::irange(dim.size())) {
1040 TORCH_CHECK( coordinates[i].dim() == 1, "torch.gradient expected each element of spacing to have one dimension, but got an element with ", coordinates[i].dim(), " dimensions!");
1041 int64_t direction = maybe_wrap_dim(dim[i], self.dim());
1042 Tensor prepend, append;
1043 std::vector<int64_t> shape(self.dim(),1);
1044 shape[ direction ] = -1;
1045
1046 auto ax_dx = coordinates[i].diff(1,0);
1047 auto dx1 = at::slice(ax_dx, 0, 0, -1);
1048 auto dx2 = at::slice(ax_dx, 0, 1);
1049 auto a = ( -dx2 / (dx1*(dx1+dx2)) ).reshape(shape);
1050 auto b = ( (dx2-dx1) / (dx1*dx2) ).reshape(shape);
1051 auto c = ( dx1 / (dx2*(dx1+dx2)) ).reshape(shape);
1052
1053 auto center = a * at::slice(self, direction, 0, -2) + b * at::slice(self, direction , 1, -1) + c * at::slice(self, direction, 2);
1054 if (edge_order == 1) {
1055 prepend = (at::slice(self, direction, 1, 2 ) - at::slice(self, direction, 0, 1 )) / ax_dx[0] ;
1056 append = (at::slice(self, direction, -1 ) - at::slice(self, direction, -2, -1 )) / ax_dx[-1] ;
1057 } else if (edge_order == 2) {
1058 a =-(2.0 * ax_dx[0] + ax_dx[1]) / (ax_dx[0] * (ax_dx[0] + ax_dx[1])) ;
1059 b = ( ax_dx[0] + ax_dx[1]) / (ax_dx[0] * ax_dx[1]) ;
1060 c = ( -ax_dx[0] ) / (ax_dx[1] * (ax_dx[0] + ax_dx[1]));
1061 prepend = a * at::slice(self, direction, 0, 1) + b * at::slice(self, direction, 1, 2) + c * at::slice(self, direction, 2, 3);
1062
1063 a = ( ax_dx[-1] ) / (ax_dx[-2] * (ax_dx[-1] + ax_dx[-2]));
1064 b =-( ax_dx[-1] + ax_dx[-2]) / (ax_dx[-1] * ax_dx[-2]);
1065 c = (2 * ax_dx[-1] + ax_dx[-2]) / (ax_dx[-1] * (ax_dx[-1] + ax_dx[-2]));
1066 append = a * at::slice(self, direction, -3, -2) + b * at::slice(self, direction, -2, -1) + c * at::slice(self, direction, -1);
1067 }
1068
1069 result.emplace_back(prepend_append_on_dim(center, prepend, append, direction));
1070 }
1071 return result;
1072 }
1073
1074 static std::vector<Tensor> gradient_helper_float(const Tensor& self, ArrayRef<Scalar> spacing, IntArrayRef dim, int64_t edge_order) {
1075 std::vector<Tensor> result;
1076 for (const auto i : c10::irange(dim.size())) {
1077 int64_t direction = maybe_wrap_dim(dim[i], self.dim());
1078 const auto& ax_dx = spacing[i];
1079 Tensor prepend, append;
1080 auto center = (at::slice(self,direction, 2 ) - at::slice(self, direction, 0, -2 ) ) / ax_dx;
1081 if (edge_order==1) {
1082 prepend = (at::slice(self,direction, 1, 2) - at::slice(self, direction, 0, 1 ) ) / ax_dx;
1083 append = (at::slice(self,direction, -1 ) - at::slice(self, direction, -2, -1) ) / ax_dx ;
1084 } else if (edge_order==2) {
1085 prepend = (-1.5 * at::slice(self, direction, 0, 1) + 2 * at::slice(self, direction, 1, 2) - 0.5 * at::slice(self, direction, 2, 3))/ ax_dx;
1086 append = (0.5 * at::slice(self, direction, -3, -2) - 2 * at::slice(self, direction, -2, -1) + 1.5 * at::slice(self, direction, -1)) / ax_dx;
1087 }
1088
1089 result.emplace_back(prepend_append_on_dim(center/2, prepend, append, direction));
1090 }
1091 return result;
1092 }
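// The interior stencil above is the standard central difference for uniform
// spacing h = ax_dx: (f[i+1] - f[i-1]) / (2*h), where the division by 2 comes
// from the `center/2` when the result is assembled. The edge_order==2 boundary
// values use the usual one-sided (-1.5*f[0] + 2*f[1] - 0.5*f[2]) / h forms.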
1093
1094 static std::vector<int64_t> gradient_dim_preprocess(const Tensor& self, std::optional<int64_t> dim) {
1095   // If dim is provided as an integer, we compute the gradient only along that direction.
1096   // If it is not provided at all, we compute the gradient along all directions.
1097   // Finally, if dim is provided as a vector of ints, this function is not expected to be called.
1098 if (dim.has_value()) {
1099 return std::vector<int64_t>{dim.value()};
1100 }
1101
1102 std::vector<int64_t> axis(self.dim());
1103 std::iota(axis.begin(), axis.end(), 0);
1104 return axis;
1105 }
1106
1107 std::vector<Tensor> gradient(const Tensor& self, TensorList coordinates, IntArrayRef dim, int64_t edge_order) {
1108 pre_check_gradient(self,
1109 std::optional<int64_t>(coordinates.size()),
1110 at::OptionalIntArrayRef(dim),
1111 edge_order);
1112 return gradient_helper(self, coordinates, dim, edge_order);
1113 }
1114
1115 std::vector<Tensor> gradient(const Tensor& self, TensorList coordinates, std::optional<int64_t> dim, int64_t edge_order) {
1116 const auto processed_dim = gradient_dim_preprocess(self, dim);
1117 pre_check_gradient(self,
1118 std::optional<int64_t>(coordinates.size()),
1119 dim.has_value() ? at::OptionalIntArrayRef(processed_dim) : std::nullopt,
1120 edge_order);
1121 return gradient_helper(self, coordinates, processed_dim, edge_order);
1122 }
1123
1124 std::vector<Tensor> gradient(const Tensor& self, c10::ArrayRef<Scalar> spacing, IntArrayRef dim, int64_t edge_order) {
1125 pre_check_gradient(self,
1126 std::optional<int64_t>(spacing.size()),
1127 at::OptionalIntArrayRef(dim),
1128 edge_order);
1129 return gradient_helper_float(self, spacing, dim, edge_order);
1130 }
1131
1132 std::vector<Tensor> gradient(const Tensor& self, ArrayRef<Scalar> spacing, std::optional<int64_t> dim, int64_t edge_order) {
1133 const auto processed_dim = gradient_dim_preprocess(self, dim);
1134 pre_check_gradient(self,
1135 std::optional<int64_t>(spacing.size()),
1136 dim.has_value() ? at::OptionalIntArrayRef(processed_dim) : std::nullopt,
1137 edge_order);
1138 return gradient_helper_float(self, spacing, processed_dim, edge_order);
1139 }
1140
1141 std::vector<Tensor> gradient(const Tensor& self, const Scalar& unit_size, IntArrayRef dim, int64_t edge_order) {
1142   // When spacing is given as a scalar and dim is given as an IntArrayRef, the scalar value
1143   // is taken as the unit size along every dimension listed in dim.
1144 std::vector<Scalar> spacing(dim.size(), unit_size);
1145 pre_check_gradient(self,
1146 std::optional<int64_t>(spacing.size()),
1147 at::OptionalIntArrayRef(dim),
1148 edge_order);
1149 return gradient_helper_float(self, spacing, dim, edge_order);
1150 }
1151
1152 std::vector<Tensor> gradient(const Tensor& self, const std::optional<Scalar>& unit_size, std::optional<int64_t> dim, int64_t edge_order) {
1153 const auto processed_dim = gradient_dim_preprocess(self, dim);
1154   // When unit_size is not provided, it is assumed to be equal to 1.
1155   // When dim has an integer value, the gradient is computed only along that direction; when
1156   // it is not provided, the gradient is computed along all directions.
1157 std::vector<Scalar> spacing(dim.has_value() ? 1 : self.dim(),
1158 unit_size.has_value() ? unit_size.value() : 1.0) ;
1159 pre_check_gradient(self,
1160 unit_size.has_value() ? std::optional<int64_t>(spacing.size()) : std::nullopt,
1161 dim.has_value() ? at::OptionalIntArrayRef(processed_dim) : std::nullopt,
1162 edge_order);
1163 return gradient_helper_float(self, spacing, processed_dim, edge_order);
1164 }
1165
1166 std::vector<Tensor> gradient(const Tensor& self, IntArrayRef dim, int64_t edge_order) {
1167 std::vector<Scalar> spacing(dim.size(), 1.0) ;
1168 pre_check_gradient(self,
1169 std::optional<int64_t>(spacing.size()),
1170 at::OptionalIntArrayRef(dim),
1171 edge_order);
1172 return gradient_helper_float(self, spacing, dim, edge_order);
1173 }
1174
1175 // ALL REDUCE #################################################################
1176
1177 inline bool should_use_acc_buffer(at::TensorIterator& iter) {
1178 const auto ndim = iter.ndim();
1179 if (!iter.device().is_cpu() || iter.noutputs() != 1) {
1180 return false;
1181 }
1182 if (!at::isReducedFloatingType(iter.common_dtype())) {
1183 return false;
1184 }
1185 if (ndim < 2) {
1186 return false;
1187 }
1188 auto out_strides = iter.strides(0);
1189 for (const auto dim : c10::irange(0, 2)) {
1190 if (out_strides[dim] != 0) {
1191 return false;
1192 }
1193 }
1194 return true;
1195 }
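// A sketch of what the check above looks for (illustrative): a CPU reduction
// with a single output whose common dtype is bf16/fp16 and whose output is
// broadcast (stride 0) along the first two iterator dimensions, i.e. at least
// two reduced dims. In that case sum_out below accumulates into a temporary
// float tensor first and copies back.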
1196
1197 TORCH_IMPL_FUNC(sum_out)
1198 (const Tensor& self,
1199 OptionalIntArrayRef opt_dim,
1200 bool keepdim,
1201 std::optional<ScalarType> opt_dtype,
1202 const Tensor& result) {
1203 auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type());
1204 if (iter.numel() == 0) {
1205 result.zero_();
1206 } else {
1207 // Here is a limitation of TensorIterator reductions for permuted input with lower precision on CPU.
1208 // Consider the case: TensorIterator coalesces such input and output to >= 2 dims tensors,
1209 // and the output stride is [0, 0, x, x, ...] with x >= 0 (two reduced dimensions and non-reduced dims).
1210 // Since the reduction loop only operates on two dimensions at a time,
1211     // the intermediate sums are forced to accumulate in the second reduced dim at lower precision.
1212 // See https://github.com/pytorch/pytorch/issues/83149
1213 if (should_use_acc_buffer(iter)) {
1214 auto tmp_output = at::empty(result.sizes(), result.options().dtype(kFloat));
1215 at::sum_outf(self.to(ScalarType::Float), opt_dim, keepdim, /*dtype=*/std::nullopt, tmp_output);
1216 result.copy_(tmp_output);
1217 } else{
1218 sum_stub(iter.device_type(), iter);
1219 }
1220 }
1221 }
1222
1223 Tensor sum(const Tensor &self, std::optional<ScalarType> dtype) {
1224 return at::sum(self, IntArrayRef{}, false, dtype);
1225 }
1226
1227 Tensor sum(const Tensor& self, DimnameList dim, bool keepdim, std::optional<ScalarType> dtype) {
1228 return at::sum(self, dimnames_to_positions(self, dim), keepdim, dtype);
1229 }
1230
1231 Tensor& sum_out(const Tensor& self, DimnameList dim,
1232 bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1233 return at::sum_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype);
1234 }
1235
1236 Tensor& nansum_out(const Tensor& self, at::OptionalIntArrayRef dim,
1237 bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1238 if (self.device().is_cpu()) {
1239 TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum does not support complex inputs");
1240 }
1241
1242 // For integral types, use existing sum as
1243   // integral types don't have `NaN`.
1244 if (c10::isIntegralType(self.scalar_type(), true)){
1245 return at::sum_out(result, self, dim, keepdim, opt_dtype);
1246 }
1247
1248 ScalarType dtype = get_dtype_from_result(result, opt_dtype);
1249 auto iter = make_reduction("nansum", result, self, dim, keepdim, dtype);
1250 if (iter.numel() == 0) {
1251 result = result.zero_();
1252 } else {
1253 nansum_stub(iter.device_type(), iter);
1254 }
1255 return result;
1256 }
1257
1258 Tensor nansum(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
1259 ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
1260 Tensor result = create_reduction_result(self, dim, keepdim, dtype);
1261 return at::native::nansum_out(self, dim, keepdim, dtype, result);
1262 }
1263
1264 namespace {
1265 template<typename scalar_t, typename accscalar_t = at::acc_type<scalar_t, false>>
1266 void inline set_result(Tensor& result, accscalar_t sum)
1267 {
1268 if constexpr (std::is_integral_v<accscalar_t>) {
1269 // all integer types get promoted to kLong
1270 *result.data_ptr<int64_t>() = sum;
1271 } else {
1272 *result.data_ptr<scalar_t>() = sum;
1273 }
1274 }
1275 }
1276 // NOTE: this could be implemented via diag and sum, but this has perf problems,
1277 // see https://github.com/pytorch/pytorch/pull/47305,
1278 Tensor trace_cpu(const Tensor& self) {
1279 Tensor result;
1280   // Returns the ScalarType of the self tensor if the tensor has a non-integral type.
1281   // If self is an integer type tensor, at::kLong is returned since promote_integers
1282   // is set to true.
1283 ScalarType dtype = get_dtype_from_self(self, std::nullopt, true);
1284 result = at::empty({}, self.options().dtype(dtype));
1285 AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] {
1286 using accscalar_t = at::acc_type<scalar_t, false>;
1287 accscalar_t sum = 0;
1288 const auto* t_data = self.const_data_ptr<scalar_t>();
1289
1290 int64_t t_stride_0, t_stride_1, t_diag_size;
1291
1292 TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());
1293
1294 t_stride_0 = self.stride(0);
1295 t_stride_1 = self.stride(1);
1296
1297 t_diag_size = std::min(self.size(0), self.size(1));
1298 for (const auto i : c10::irange(t_diag_size)) {
1299 sum += t_data[i * (t_stride_0 + t_stride_1)];
1300 }
1301 set_result<scalar_t>(result, sum);
1302
1303 });
1304
1305 return result;
1306 }
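// Illustrative arithmetic for the loop above: for a 2x2 matrix [[a, b], [c, d]]
// stored row-major (strides {2, 1}), the diagonal elements sit at flat offsets
// 0 and 3 = 1 * (stride0 + stride1), so trace = a + d; integer inputs
// accumulate into int64_t and the result is returned as kLong.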
1307
1308 static void impl_func_prod(
1309 const Tensor& self,
1310 IntArrayRef dims,
1311 bool keepdim,
1312 std::optional<ScalarType> dtype,
1313 const Tensor& result) {
1314 auto iter = meta::make_reduction_from_out_ty(self, result, dims, keepdim, result.scalar_type());
1315 if (iter.numel() == 0) {
1316 result.fill_(1);
1317 } else {
1318 prod_stub(iter.device_type(), iter);
1319 }
1320 }
1321
1322 TORCH_IMPL_FUNC(prod_out)
1323 (const Tensor& self,
1324 int64_t dim,
1325 bool keepdim,
1326 std::optional<ScalarType> dtype,
1327 const Tensor& result) {
1328 impl_func_prod(self, dim, keepdim, dtype, result);
1329 }
1330
1331 Tensor prod(const Tensor &self, std::optional<ScalarType> opt_dtype) {
1332 auto dtype = get_dtype_from_self(self, opt_dtype, true);
1333 auto shape = meta::get_reduction_shape(self, {}, false);
1334 Tensor result = at::empty(shape, self.options().dtype(dtype));
1335 impl_func_prod(self, {}, false, dtype, result);
1336 return result;
1337 }
1338
1339 Tensor prod(const Tensor& self, Dimname dim, bool keepdim, std::optional<ScalarType> dtype) {
1340 return at::prod(self, dimname_to_position(self, dim), keepdim, dtype);
1341 }
1342
1343 Tensor& prod_out(const Tensor& self, Dimname dim,
1344 bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1345 return at::prod_out(result, self, dimname_to_position(self, dim), keepdim, opt_dtype);
1346 }
1347
1348 TORCH_IMPL_FUNC(mean_out)
1349 (const Tensor& self,
1350 OptionalIntArrayRef opt_dim,
1351 bool keepdim,
1352 std::optional<ScalarType> opt_dtype,
1353 const Tensor& result) {
1354 ScalarType dtype = result.scalar_type();
1355 // TODO: the TensorIterator reduction implementation of mean
1356 // (mean_kernel_impl()) is unvectorized and leads to very poor performance
1357 // for production workloads. Once that's fixed, the mean_stub path below can
1358 // be used on CPU as well, in lieu of this sum + divide implementation.
1359 if (self.device().is_cpu()) {
1360 int64_t dim_prod = 1;
1361 if (!opt_dim.has_value() || opt_dim.value().empty() || self.ndimension() == 0) {
1362 dim_prod = self.numel();
1363 } else {
1364 auto dim = opt_dim.value();
1365 for (auto d : dim) {
1366 dim_prod *= self.size(d);
1367 }
1368 }
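// For intuition: dim_prod is the number of input elements averaged into each
// output element, e.g. reducing a (2, 3, 4) tensor over dims {0, 2} gives
// dim_prod = 2 * 4 = 8, so each mean below is the corresponding sum / 8.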
1369 auto& result_mut = const_cast<Tensor&>(result);
1370 // For accuracy reasons, BF16/FP16 mean should be computed via the
1371 // following approach:
1372 // cast_fp32 -> sum -> div -> cast_bf16_or_fp16
1373 //
1374 // Such an approach is necessary because if we were to choose the same
1375 // approach for BF16/FP16 as FP32 here, then it would have resulted in
1376 // the following code-flow -
1377 // cast_fp32 -> sum -> cast_bf16 -> cast_fp32 -> div -> cast_bf16,
1378 // which, in turn, does not produce as accurate results.
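// Rough numeric illustration: bfloat16 has only 8 significand bits, so
// 256.0 + 1.0 already rounds back to 256.0 in BF16. Accumulating many
// elements directly in BF16/FP16 therefore loses low-order contributions,
// while accumulating in FP32 and casting once at the end does not.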
1379 bool is_half_type = (dtype == kHalf || dtype == kBFloat16);
1380 auto sum_out_dtype = is_half_type ? ScalarType::Float : dtype;
1381 result_mut = is_half_type ? result_mut.to(sum_out_dtype) : result_mut;
1382 // If dtype is FP16 or BF16, self (input tensor) will initially be cast to
1383 // FP32 in sum_out. This results in having to read that FP32 tensor again,
1384 // but maybe in the future, we could revise the implementation to not
1385 // materialize that intermediate FP32 tensor. That approach would probably
1386 // require some modifications in binary_kernel_reduce_vec(),
1387 // TensorIteratorBase::for_each(), and
1388 // TensorIteratorBase::serial_for_each(), apart from sum kernel for CPU.
1389 at::sum_out(result_mut, self, opt_dim, keepdim, sum_out_dtype).div_(dim_prod);
1390 // After sum & div, cast result_mut back to BF16 or FP16, if required.
1391 result_mut = is_half_type ? result_mut.to(dtype) : result_mut;
1392 } else {
1393 // device is not CPU
1394 auto iter = at::meta::make_reduction_from_out_ty(
1395 self, result, opt_dim, keepdim, dtype);
1396 if (iter.numel() == 0) {
1397 result.fill_(std::numeric_limits<double>::quiet_NaN());
1398 } else {
1399 mean_stub(iter.device_type(), iter);
1400 }
1401 }
1402 }
1403
1404 Tensor mean(const Tensor &self, std::optional<ScalarType> dtype) {
1405 return at::mean(self, IntArrayRef{}, false, dtype);
1406 }
1407
1408 Tensor mean(const Tensor& self, DimnameList dim, bool keepdim, std::optional<ScalarType> dtype) {
1409 return at::mean(self, dimnames_to_positions(self, dim), keepdim, dtype);
1410 }
1411
1412 Tensor& mean_out(const Tensor& self, DimnameList dim,
1413 bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
1414 return at::mean_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype);
1415 }
1416
1417 Tensor& mean_dtype_out(const Tensor &self, std::optional<ScalarType> dtype, Tensor& result) {
1418 TORCH_CHECK(
1419 canCast(self.scalar_type(), result.scalar_type()),
1420 "mean.dtype_out(): input types can't be cast to the desired output type ",
1421 result.scalar_type());
1422 // at::mean_out should make sure dtype and result.scalar_type() are the same
1423 return at::mean_out(result, self, IntArrayRef{}, false, dtype);
1424 }
1425
1426 // TODO(@heitorschueroff) implement custom kernels for nanmean
1427 Tensor& nanmean_out(
1428 const Tensor& self,
1429 at::OptionalIntArrayRef dim,
1430 bool keepdim,
1431 std::optional<ScalarType> opt_dtype,
1432 Tensor& result) {
1433 TORCH_CHECK(
1434 self.is_floating_point() || self.is_complex(),
1435 "nanmean(): expected input to have floating point or complex dtype but got ",
1436 self.scalar_type());
1437 const auto factor = at::native::isnan(self).logical_not_().sum(dim, keepdim);
1438 at::native::nansum_out(self, dim, keepdim, opt_dtype, result).div_(factor);
1439 return result;
1440 }
1441
1442 Tensor nanmean(
1443 const Tensor& self,
1444 at::OptionalIntArrayRef dim,
1445 bool keepdim,
1446 std::optional<ScalarType> opt_dtype) {
1447 TORCH_CHECK(
1448 self.is_floating_point() || self.is_complex(),
1449 "nanmean(): expected input to have floating point or complex dtype but got ",
1450 self.scalar_type());
1451 const auto factor =
1452 at::native::isnan(self.detach()).logical_not_().sum(dim, keepdim);
1453 return at::nansum(self, dim, keepdim, opt_dtype).div(factor);
1454 }
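// Worked example: nanmean divides the NaN-ignoring sum by the count of
// non-NaN elements, e.g. for {1., NaN, 3.} the factor is 2 and the result is
// nansum / factor = 4 / 2 = 2. An all-NaN slice has factor 0, so the division
// produces NaN as expected.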
1455
1456 static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
1457 // can't take max of empty tensor
1458 if (self.numel() != 0) {
1459 // For complex numbers, use the real part to calculate the max. Based on
1460 // https://scicomp.stackexchange.com/questions/34273/log-sum-exp-trick-for-signed-complex-numbers
1461 auto maxes = at::amax(at::real(self), dims, true);
1462 auto maxes_squeezed = (keepdim ? maxes : at::squeeze(maxes, dims));
1463 maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
1464 at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
1465 result.log_().add_(maxes_squeezed);
1466 } else {
1467 at::sum_out(result, at::exp(self), dims, keepdim);
1468 result.log_();
1469 }
1470 return result;
1471 }
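// Worked equation: the implementation above uses the identity
//   logsumexp(x) = m + log(sum_i exp(x_i - m)),   m = max_i Re(x_i),
// which holds for any finite shift m. Subtracting the max keeps exp() in a
// safe range, e.g. for x = {1000, 1001}, exp(1000) overflows in double
// precision, but with m = 1001 the result is 1001 + log(exp(-1) + 1) ~= 1001.3133.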
1472
1473 Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
1474 // Complex type implies floating point type
1475 TORCH_CHECK(at::isFloatingType(result.scalar_type()) || at::isComplexType(result.scalar_type()),
1476 "logsumexp(): Expected floating point type for result tensor, but got: ",
1477 result.scalar_type());
1478 {
1479 NoNamesGuard guard;
1480 if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
1481 // for integral inputs, promote input to default floating type.
1482 auto default_dtype = at::typeMetaToScalarType(c10::get_default_dtype());
1483 logsumexp_out_impl(result, self.to(default_dtype), dims, keepdim);
1484 } else {
1485 logsumexp_out_impl(result, self, dims, keepdim);
1486 }
1487 }
1488 namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
1489 return result;
1490 }
1491
1492 Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
1493 TensorOptions result_options;
1494 if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
1495 // even for integral inputs, result is floating dtype
1496 auto default_dtype = at::typeMetaToScalarType(c10::get_default_dtype());
1497 result_options = self.options().dtype(default_dtype);
1498 } else {
1499 result_options = self.options();
1500 }
1501 auto result = at::empty({0}, result_options);
1502 return at::logsumexp_outf(self, dims, keepdim, result);
1503 }
1504
1505 Tensor logsumexp(const Tensor& self, DimnameList dims, bool keepdim) {
1506 return at::logsumexp(self, dimnames_to_positions(self, dims), keepdim);
1507 }
1508
1509 Tensor& logsumexp_out(const Tensor& self, DimnameList dims, bool keepdim, Tensor& result) {
1510 return at::logsumexp_out(result, self, dimnames_to_positions(self, dims), keepdim);
1511 }
1512
1513 // special_logsumexp, alias for logsumexp
1514 Tensor special_logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
1515 return self.logsumexp(dims, keepdim);
1516 }
1517 Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
1518 return at::logsumexp_out(result, self, dims, keepdim);
1519 }
1520
1521 static void impl_func_norm(
1522 const Tensor& self,
1523 const OptionalScalarRef& opt_p,
1524 IntArrayRef dim,
1525 bool keepdim,
1526 std::optional<ScalarType> opt_dtype,
1527 const Tensor& result) {
1528 // This implementation is kept (rather than deprecated) because it is called in a
1529 // number of places in the codebase. Those call sites should eventually be switched to linalg_vector_norm.
1530 auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
1531 at::linalg_vector_norm_out(const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
1532 }
1533
1534 TORCH_IMPL_FUNC(norm_out)
1535 (const Tensor& self,
1536 const OptionalScalarRef p,
1537 IntArrayRef dim,
1538 bool keepdim,
1539 const Tensor& result) {
1540 impl_func_norm(self, p, dim, keepdim, std::nullopt, result);
1541 }
1542
1543 TORCH_IMPL_FUNC(norm_dtype_out)
1544 (const Tensor& self,
1545 const OptionalScalarRef p,
1546 IntArrayRef dim,
1547 bool keepdim,
1548 ScalarType dtype,
1549 const Tensor& result) {
1550 impl_func_norm(self, p, dim, keepdim, dtype, result);
1551 }
1552
1553 Tensor sparse_norm(
1554 const Tensor& self,
1555 const std::optional<Scalar>& p,
1556 IntArrayRef dim,
1557 bool keepdim) {
1558 return at::native_norm(self, p, dim, keepdim, std::nullopt);
1559 }
1560
1561 Tensor sparse_dtype_norm(
1562 const Tensor& self,
1563 const std::optional<Scalar>& p,
1564 IntArrayRef dim,
1565 bool keepdim,
1566 ScalarType dtype) {
1567 return at::native_norm(self, p, dim, keepdim, dtype);
1568 }
1569
1570 Tensor norm(const Tensor& self, const std::optional<Scalar>& p, ScalarType dtype) {
1571 return at::norm(self, p, IntArrayRef{}, false, dtype);
1572 }
1573
1574 Tensor norm(const Tensor& self, const Scalar& p) {
1575 return at::norm(self, p, IntArrayRef{}, false);
1576 }
1577
1578 inline TensorIterator get_allany_iter(
1579 const Tensor& self,
1580 const Tensor& result,
1581 OptionalIntArrayRef dims,
1582 bool keepdim) {
1583 if (self.is_cuda()) {
1584 // As CUDA supports dynamic type casting, we use this overload of
1585 // `make_reduction`, which doesn't cast the input to the result type (i.e. kBool);
1586 // otherwise we use the overload below, which casts the input to kBool (an
1587 // extra operation).
1588 return meta::make_reduction(self, result, dims, keepdim, self.scalar_type());
1589 }
1590 return meta::make_reduction_from_out_ty(
1591 self, result, dims, keepdim, result.scalar_type());
1592 }
1593
1594 template <int identity, typename Stub>
1595 inline void allany_impl(
1596 const Tensor& self,
1597 const Tensor& result,
1598 OptionalIntArrayRef dims,
1599 bool keepdim,
1600 Stub& stub) {
1601 if (self.numel() == 0) {
1602 result.fill_(identity);
1603 } else if (self.numel() == 1) {
1604 result.copy_(self.view_as(result).to(at::kBool));
1605 } else {
1606 auto iter = get_allany_iter(self, result, dims, keepdim);
1607 stub(iter.device_type(), iter);
1608 }
1609 }
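// Note: `identity` is the value of the reduction over an empty input, i.e.
// 1 (true) for all() and 0 (false) for any(), matching the template arguments
// used by the TORCH_IMPL_FUNC wrappers below.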
1610
1611 TORCH_IMPL_FUNC(all_out)
1612 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
1613 allany_impl<1>(self, result, dim, keepdim, and_stub);
1614 }
1615
1616 TORCH_IMPL_FUNC(all_dims_out)
1617 (const Tensor& self, OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
1618 allany_impl<1>(self, result, dim, keepdim, and_stub);
1619 }
1620
1621 TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
1622 allany_impl<1>(self, result, {}, false, and_stub);
1623 }
1624
1625 TORCH_IMPL_FUNC(any_out)
1626 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
1627 allany_impl<0>(self, result, dim, keepdim, or_stub);
1628 }
1629
1630 TORCH_IMPL_FUNC(any_dims_out)
1631 (const Tensor& self, OptionalIntArrayRef dim, bool keepdim, const Tensor& result) {
1632 allany_impl<0>(self, result, dim, keepdim, or_stub);
1633 }
1634
1635 TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
1636 allany_impl<0>(self, result, {}, false, or_stub);
1637 }
1638
1639 template <bool is_all>
1640 Tensor allany_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
1641 // Default implementation in terms of all-reduce or single dim reduce
1642 if (!dim) {
1643 Tensor out;
1644 if constexpr (is_all) {
1645 out = self.all();
1646 } else {
1647 out = self.any();
1648 }
1649
1650 if (keepdim) {
1651 DimVector out_shape(self.dim(), 1);
1652 return out.expand(out_shape);
1653 }
1654 return out;
1655 }
1656
1657 if (dim->empty()) {
1658 if (self.scalar_type() == kByte) {
1659 // Convert to a 1 or 0 mask
1660 auto out = at::empty_like(self);
1661 return at::ne_outf(self, 0, out);
1662 } else {
1663 return at::_to_copy(self, kBool);
1664 }
1665 }
1666
1667 Tensor out = self;
1668 for (auto d : *dim) {
1669 if constexpr (is_all) {
1670 out = out.all(d, /*keepdim=*/true);
1671 } else {
1672 out = out.any(d, /*keepdim=*/true);
1673 }
1674 }
1675 return keepdim ? out : out.squeeze(*dim);
1676 }
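// For intuition: the multi-dim default above decomposes the reduction into a
// chain of single-dim reductions, e.g. `all` over dims {0, 2} of a 3-d tensor
// becomes roughly
//   self.all(0, /*keepdim=*/true).all(2, /*keepdim=*/true)
// followed by squeezing the reduced dims when keepdim is false.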
1677
1678 Tensor all_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
1679 return allany_dims_default<true>(self, dim, keepdim);
1680 }
1681
1682 Tensor any_dims_default(const Tensor &self, OptionalIntArrayRef dim, bool keepdim) {
1683 return allany_dims_default<false>(self, dim, keepdim);
1684 }
1685
1686 Tensor& all_dims_out_default(
1687 const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
1688 TORCH_CHECK(self.device() == result.device(), "all.dims: output must be on the same device as input");
1689 auto tmp = all_dims_default(self, dim, keepdim);
1690 at::native::resize_output(result, tmp.sizes());
1691 return result.copy_(tmp);
1692 }
1693
1694 Tensor& any_dims_out_default(
1695 const Tensor &self, OptionalIntArrayRef dim, bool keepdim, Tensor &result) {
1696 TORCH_CHECK(self.device() == result.device(), "any.dims: output must be on the same device as input");
1697 auto tmp = any_dims_default(self, dim, keepdim);
1698 at::native::resize_output(result, tmp.sizes());
1699 return result.copy_(tmp);
1700 }
1701
1702 TORCH_IMPL_FUNC(amin_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
1703 auto iter =
1704 meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
1705 if (iter.numel() != 0) {
1706 min_values_stub(iter.device_type(), iter);
1707 }
1708 }
1709
1710 TORCH_IMPL_FUNC(amax_out) (const Tensor& self, IntArrayRef dim, bool keepdim, const Tensor& result) {
1711 auto iter =
1712 meta::make_reduction(self, result, dim, keepdim, self.scalar_type());
1713 if (iter.numel() != 0) {
1714 max_values_stub(iter.device_type(), iter);
1715 }
1716 }
1717
1718 template <class Stub>
1719 void argmax_argmin_impl(
1720 const Tensor& self,
1721 std::optional<int64_t> dim,
1722 bool keepdim,
1723 const Tensor& result,
1724 Stub& stub) {
1725 c10::MaybeOwned<Tensor> in;
1726 DimVector dims;
1727 int64_t _dim = 0;
1728
1729 if (dim.has_value()) {
1730 _dim = maybe_wrap_dim(dim.value(), self.dim());
1731 auto sizes = self.sizes();
1732
1733 if (sizes[_dim] == 1) {
1734 result.fill_(0);
1735 return;
1736 }
1737
1738 dims = IntArrayRef(_dim);
1739 in = c10::MaybeOwned<Tensor>::borrowed(self);
1740 } else {
1741 in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
1742 keepdim = false;
1743 }
1744
1745 auto iter =
1746 meta::make_reduction(*in, result, dims, keepdim, self.scalar_type());
1747
1748 if (iter.numel() != 0) {
1749 stub(iter.device_type(), iter);
1750 }
1751 }
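// Note: when `dim` is not provided the input is flattened first, so the
// returned index refers to the flattened tensor (e.g. the argmax of a (2, 3)
// tensor lies in [0, 6)); when the reduced dim has size 1, every index is
// trivially 0, hence the fill_(0) shortcut above.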
1752
1753 TORCH_IMPL_FUNC(argmax_out)
1754 (const Tensor& self,
1755 std::optional<int64_t> dim,
1756 bool keepdim,
1757 const Tensor& result) {
1758 argmax_argmin_impl(self, dim, keepdim, result, argmax_stub);
1759 }
1760
1761 TORCH_IMPL_FUNC(argmin_out)
1762 (const Tensor& self,
1763 std::optional<int64_t> dim,
1764 bool keepdim,
1765 const Tensor& result) {
1766 argmax_argmin_impl(self, dim, keepdim, result, argmin_stub);
1767 }
1768
1769 static double std_var_all_cpu(const Tensor& self, double correction, bool take_sqrt) {
1770 const auto dtype = self.scalar_type();
1771 TORCH_CHECK(dtype == kDouble || dtype == kFloat,
1772 "std_var_all: Unsupported dtype ", dtype);
1773
1774 auto mean = self.mean().item<double>();
1775 auto iter = TensorIteratorConfig()
1776 .add_const_input(self)
1777 .build();
1778
1779 auto reduction = [&](int64_t begin, int64_t end, double thread_sum) {
1780 AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "std_var_all_cpu", [&] {
1781 iter.serial_for_each([&] (char** data, const int64_t* strides, int64_t size0, int64_t size1) {
1782 const double local_mean = mean;
1783 const int64_t inner_stride = strides[0];
1784 const int64_t outer_stride = strides[1];
1785
1786 double local_sum = 0.0;
1787 for (const auto i : c10::irange(size1)) {
1788 const char* row_ptr = data[0] + outer_stride * i;
1789 for (const auto j : c10::irange(size0)) {
1790 const auto ptr = reinterpret_cast<const scalar_t*>(row_ptr + inner_stride * j);
1791 auto dx = (static_cast<double>(*ptr) - local_mean);
1792 local_sum += dx * dx;
1793 }
1794 }
1795 thread_sum += local_sum;
1796 }, {begin, end});
1797 });
1798
1799 return thread_sum;
1800 };
1801
1802 // ((x - mean)**2).sum()
1803 const double sum_dx2 = at::parallel_reduce(
1804 0, iter.numel(), at::internal::GRAIN_SIZE, 0.0, reduction, std::plus<>{});
1805
1806 const auto var = [&] () __ubsan_ignore_float_divide_by_zero__ {
1807 return sum_dx2 / std::max(0.0, self.numel() - correction);
1808 }();
1809 const auto result = take_sqrt ? std::sqrt(var) : var;
1810
1811 if (dtype == kFloat) {
1812 // Convert to infinity if the value is out of range for a float.
1813 // Doing it now prevents checked_convert from failing later.
1814 return static_cast<float>(result);
1815 }
1816 return result;
1817 }
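// Worked equation: the all-reduce above computes
//   var = sum_i (x_i - mean)^2 / max(0, N - correction)
// e.g. for x = {1, 2, 3, 4}: mean = 2.5, the sum of squared deviations is 5,
// and with the default correction = 1 (Bessel's correction) var = 5 / 3,
// roughly 1.667; take_sqrt then turns this into the standard deviation.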
1818
1819 namespace {
1820 inline void warn_invalid_degrees_of_freedom(const char* fname, const TensorIterator& iter, double correction) {
1821 int64_t reducing_over_num_elements = iter.num_output_elements() == 0 ? 0 : iter.numel() / iter.num_output_elements();
1822 if (reducing_over_num_elements - correction <= 0) {
1823 TORCH_WARN(fname, "(): degrees of freedom is <= 0. Correction should be strictly less than the reduction factor (input numel divided by output numel).");
1824 }
1825 }
1826 } // namespace
1827
1828 static Tensor& std_var_out(
1829 const char* fname, Tensor& result, const Tensor& self,
1830 at::OptionalIntArrayRef dim, const std::optional<Scalar>& correction_opt,
1831 bool keepdim, bool take_sqrt) {
1832 TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda() || self.device().is_xpu(),
1833 "std and var supports tensors on a CPU, CUDA, or XPU device only, but got: ",
1834 self.device().type());
1835 TORCH_CHECK(self.layout() == Layout::Strided,
1836 "std and var only supports strided layout, got: ", self.layout());
1837 TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
1838 "std and var only support floating point and complex dtypes");
1839
1840 if (at::isComplexType(self.scalar_type())) {
1841 // For complex, calculate variance of real and imaginary components
1842 // separately then add to get overall variance.
1843 ScalarType dtype = c10::toRealValueType(get_dtype_from_result(result, {}));
1844 Tensor real_in = at::real(self);
1845 Tensor real_out = at::empty({0}, self.options().dtype(dtype));
1846 std_var_out(
1847 fname,
1848 real_out,
1849 real_in,
1850 dim,
1851 correction_opt,
1852 keepdim,
1853 /*take_sqrt=*/false);
1854
1855 Tensor imag_in = at::imag(self);
1856 Tensor imag_out = at::empty({0}, self.options().dtype(dtype));
1857 std_var_out(
1858 fname,
1859 imag_out,
1860 imag_in,
1861 dim,
1862 correction_opt,
1863 keepdim,
1864 /*take_sqrt=*/false);
1865
1866 at::add_out(result, real_out, imag_out);
1867 if (take_sqrt) {
1868 at::sqrt_out(result, result);
1869 }
1870 return result;
1871 }
1872
1873 // Computation for floating point
1874 const auto correction = correction_opt.value_or(1).toDouble();
1875 ScalarType dtype = get_dtype_from_result(result, {});
1876 auto iter = make_reduction(fname, result, self, dim, keepdim, dtype);
1877 TORCH_CHECK(at::canCast(self.scalar_type(), result.scalar_type()),
1878 "result type ", self.scalar_type(), " can't be cast to the "
1879 "desired output type ", result.scalar_type());
1880 warn_invalid_degrees_of_freedom(fname, iter, correction);
1881
1882 if (iter.numel() == 0) {
1883 // Trivial reduction
1884 result.fill_(std::numeric_limits<double>::quiet_NaN());
1885 return result;
1886 } else if (
1887 result.numel() == 1 && iter.device_type() == kCPU &&
1888 iter.common_dtype() != kBFloat16 && iter.common_dtype() != kHalf) {
1889 // NOTE: CPU performance significantly regressed when attempting to port to
1890 // ATen,
1891 // so all-reduce has a custom implementation.
1892 // See https://github.com/pytorch/pytorch/pull/43858.
1893 result.fill_(std_var_all_cpu(self, correction, take_sqrt));
1894 } else {
1895 std_var_stub(iter.device_type(), iter, correction, take_sqrt);
1896 }
1897 return result;
1898 }
1899
1900 static std::tuple<Tensor&, Tensor&> std_var_mean_out(
1901 const char* fname, Tensor& result1, Tensor& result2, const Tensor& self,
1902 at::OptionalIntArrayRef dim, const std::optional<Scalar>& correction_opt,
1903 bool keepdim, bool take_sqrt) {
1904 AT_ASSERT(result1.defined() && result2.defined());
1905 TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_xpu(),
1906 fname, " supports tensors on a CPU, CUDA, or XPU device only, got: ",
1907 self.device().type());
1908 TORCH_CHECK(self.layout() == Layout::Strided,
1909 fname, " only supports strided layout, got: ", self.layout());
1910 TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
1911 fname, " only support floating point and complex dtypes");
1912 TORCH_CHECK(result1.scalar_type() == c10::toRealValueType(result2.scalar_type()),
1913 fname, " expected result1 to be real and match the precision of result2. Got ",
1914 result1.scalar_type(), " and ", result2.scalar_type(), ".");
1915
1916 if (at::isComplexType(self.scalar_type())) {
1917 // For complex, calculate for real and imaginary components separately then combine as:
1918 // variance = var_real + var_imag
1919 // mean = mean_real + j * mean_imag
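// Reasoning: for complex z with mean m, |z - m|^2 = (Re z - Re m)^2 +
// (Im z - Im m)^2, so Var(z) is exactly Var(Re z) + Var(Im z), which is what
// the two recursive calls below compute before at::add_out combines them.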
1920 ScalarType dtype = c10::toRealValueType(get_dtype_from_result(result1, {}));
1921 Tensor real_in = at::real(self);
1922 Tensor real_out_var = at::empty({0}, self.options().dtype(dtype));
1923 Tensor real_out_mean = at::empty({0}, self.options().dtype(dtype));
1924 std_var_mean_out(
1925 fname,
1926 real_out_var,
1927 real_out_mean,
1928 real_in,
1929 dim,
1930 correction_opt,
1931 keepdim,
1932 /*take_sqrt=*/false);
1933
1934 Tensor imag_in = at::imag(self);
1935 Tensor imag_out_var = at::empty({0}, self.options().dtype(dtype));
1936 Tensor imag_out_mean = at::empty({0}, self.options().dtype(dtype));
1937 std_var_mean_out(
1938 fname,
1939 imag_out_var,
1940 imag_out_mean,
1941 imag_in,
1942 dim,
1943 correction_opt,
1944 keepdim,
1945 /*take_sqrt=*/false);
1946
1947 at::add_out(result1, real_out_var, imag_out_var);
1948 if (take_sqrt) {
1949 at::sqrt_out(result1, result1);
1950 }
1951 at::complex_out(result2, real_out_mean, imag_out_mean);
1952 return std::tuple<Tensor&, Tensor&>(result1, result2);
1953 }
1954
1955 // Computation for floating point
1956 const auto correction = correction_opt.value_or(1).toDouble();
1957 ScalarType dtype = get_dtype_from_result(result1, {});
1958 auto iter =
1959 make_reduction(fname, result1, result2, self, dim, keepdim, dtype);
1960 warn_invalid_degrees_of_freedom(fname, iter, correction);
1961
1962 if (iter.numel() == 0) {
1963 // Trivial reduction
1964 result1.fill_(std::numeric_limits<double>::quiet_NaN());
1965 result2.fill_(std::numeric_limits<double>::quiet_NaN());
1966 } else {
1967 std_var_stub(iter.device_type(), iter, correction, take_sqrt);
1968 }
1969 return std::tuple<Tensor&, Tensor&>(result1, result2);
1970 }
1971
1972 std::tuple<Tensor, Tensor> var_mean(
1973 const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1974 return at::var_mean(
1975 self, /*dim=*/at::OptionalIntArrayRef(dim),
1976 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
1977 keepdim);
1978 }
1979
1980 std::tuple<Tensor, Tensor> std_mean(
1981 const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
1982 return at::std_mean(
1983 self, /*dim=*/at::OptionalIntArrayRef(dim),
1984 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
1985 keepdim);
1986 }
1987
1988 std::tuple<Tensor, Tensor> std_mean(const Tensor& self, bool unbiased) {
1989 return at::std_mean(
1990 self, /*dim=*/std::nullopt,
1991 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
1992 }
1993
1994 std::tuple<Tensor, Tensor> var_mean(const Tensor& self, bool unbiased) {
1995 return at::var_mean(
1996 self, /*dim=*/std::nullopt,
1997 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
1998 }
1999 std::tuple<Tensor&, Tensor&> var_mean_out(
2000 Tensor& result1, Tensor& result2, const Tensor& self, IntArrayRef dim,
2001 int64_t correction, bool keepdim) {
2002 return std_var_mean_out(
2003 "var_mean", result1, result2, self, dim, correction, keepdim, false);
2004 }
2005
2006 static TensorOptions options_to_value_type(TensorOptions opts) {
2007 auto scalar_type = typeMetaToScalarType(opts.dtype());
2008 return opts.dtype(c10::toRealValueType(scalar_type));
2009 }
2010
2011 std::tuple<Tensor, Tensor> var_mean(
2012 const Tensor& self, at::OptionalIntArrayRef dim,
2013 const std::optional<Scalar>& correction, bool keepdim) {
2014 Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
2015 Tensor result2 = at::empty({0}, self.options());
2016 return std_var_mean_out(
2017 "var_mean", result1, result2, self, dim, correction, keepdim, false);
2018 }
2019
2020 std::tuple<Tensor, Tensor> std_mean(
2021 const Tensor& self, at::OptionalIntArrayRef dim,
2022 const std::optional<Scalar>& correction, bool keepdim) {
2023 Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
2024 Tensor result2 = at::empty({0}, self.options());
2025 return std_var_mean_out(
2026 "std_mean", result1, result2, self, dim, correction, keepdim, true);
2027 }
2028
2029 Tensor var(const Tensor& self, bool unbiased) {
2030 return at::var(
2031 self, /*dim=*/std::nullopt,
2032 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
2033 }
2034
2035 Tensor var(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
2036 return at::var(
2037 self, /*dim=*/at::OptionalIntArrayRef(dim),
2038 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
2039 keepdim);
2040 }
2041
2042 Tensor& var_out(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
2043 return at::var_out(
2044 result, self, /*dim=*/at::OptionalIntArrayRef(dim),
2045 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0),
2046 keepdim);
2047 }
2048
2049 Tensor std(const Tensor& self, bool unbiased) {
2050 return at::std(
2051 self, /*dim=*/std::nullopt, /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0));
2052 }
2053
2054 Tensor std(const Tensor& self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) {
2055 return at::std(self, dim,
2056 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
2057 }
2058
2059 Tensor& std_out(const Tensor& self, at::OptionalIntArrayRef opt_dim, bool unbiased, bool keepdim, Tensor& result) {
2060 return at::std_out(result, self, opt_dim,
2061 /*correction=*/std::make_optional<Scalar>(unbiased ? 1 : 0), keepdim);
2062 }
2063
2064 Tensor std(const Tensor& self, at::OptionalIntArrayRef dim,
2065 const std::optional<Scalar>& correction, bool keepdim) {
2066 Tensor result = at::empty({0}, options_to_value_type(self.options()));
2067 return std_var_out("std", result, self, dim, correction, keepdim, true);
2068 }
2069
2070 Tensor& std_out(
2071 const Tensor& self, at::OptionalIntArrayRef dim,
2072 const std::optional<Scalar>& correction, bool keepdim, Tensor& result) {
2073 return std_var_out("std", result, self, dim, correction, keepdim, true);
2074 }
2075
2076 Tensor& var_out(
2077 const Tensor& self, at::OptionalIntArrayRef dim,
2078 const std::optional<Scalar>& correction, bool keepdim, Tensor& result) {
2079 return std_var_out("var", result, self, dim, correction, keepdim, false);
2080 }
2081
2082 Tensor var(
2083 const Tensor& self, at::OptionalIntArrayRef dim,
2084 const std::optional<Scalar>& correction, bool keepdim) {
2085 Tensor result = at::empty({0}, options_to_value_type(self.options()));
2086 return std_var_out("var", result, self, dim, correction, keepdim, false);
2087 }
2088
2089 Tensor std(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2090 return at::std(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2091 }
2092
2093 Tensor& std_out(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim, Tensor& result) {
2094 return at::std_out(result, self, dimnames_to_positions(self, dim), unbiased, keepdim);
2095 }
2096
2097 Tensor var(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2098 return at::var(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2099 }
2100
2101 Tensor& var_out(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim, Tensor& result) {
2102 return at::var_out(
2103 result, self, dimnames_to_positions(self, dim), unbiased, keepdim);
2104 }
2105
2106 std::tuple<Tensor,Tensor> var_mean(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2107 return at::var_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2108 }
2109
2110 std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
2111 return at::std_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim);
2112 }
2113
2114 Tensor std(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction, bool keepdim) {
2115 return at::std(self, dimnames_to_positions(self, dim), correction, keepdim);
2116 }
2117
2118 Tensor& std_out(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction,
2119 bool keepdim, Tensor& result) {
2120 return at::std_out(result, self, dimnames_to_positions(self, dim), correction, keepdim);
2121 }
2122
2123 Tensor var(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction, bool keepdim) {
2124 return at::var(self, dimnames_to_positions(self, dim), correction, keepdim);
2125 }
2126
2127 Tensor& var_out(const Tensor& self, DimnameList dim, const std::optional<Scalar>& correction,
2128 bool keepdim, Tensor& result) {
2129 return at::var_out(
2130 result, self, dimnames_to_positions(self, dim), correction, keepdim);
2131 }
2132
2133 std::tuple<Tensor,Tensor> var_mean(const Tensor& self, DimnameList dim,
2134 const std::optional<Scalar>& correction, bool keepdim) {
2135 return at::var_mean(self, dimnames_to_positions(self, dim), correction, keepdim);
2136 }
2137
2138 std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim,
2139 const std::optional<Scalar>& correction, bool keepdim) {
2140 return at::std_mean(self, dimnames_to_positions(self, dim), correction, keepdim);
2141 }
2142
2143 Tensor& norm_out(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim, ScalarType dtype, Tensor& result) {
2144 return at::norm_out(result, self, p, dimnames_to_positions(self, dim), keepdim, dtype);
2145 }
2146
2147 Tensor& norm_out(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim, Tensor& result) {
2148 return at::norm_out(result, self, p, dimnames_to_positions(self, dim), keepdim);
2149 }
2150
2151 Tensor norm(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim, ScalarType dtype) {
2152 return at::norm(self, p, dimnames_to_positions(self, dim), keepdim, dtype);
2153 }
2154
2155 Tensor norm(const Tensor& self, const std::optional<Scalar>& p, DimnameList dim, bool keepdim) {
2156 return at::norm(self, p, dimnames_to_positions(self, dim), keepdim);
2157 }
2158
2159 Tensor any(const Tensor& self, Dimname dim, bool keepdim) {
2160 reportNYIDimnameOverload("any");
2161 }
2162 Tensor& any_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
2163 reportNYIDimnameOverload("any");
2164 }
2165 Tensor all(const Tensor& self, Dimname dim, bool keepdim) {
2166 reportNYIDimnameOverload("all");
2167 }
2168 Tensor& all_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
2169 reportNYIDimnameOverload("all");
2170 }
2171 Tensor _is_all_true(const Tensor& self) {
2172 TORCH_INTERNAL_ASSERT(self.scalar_type() == at::kBool);
2173 return self.all();
2174 }
2175 Tensor _is_any_true(const Tensor& self) {
2176 TORCH_INTERNAL_ASSERT(self.scalar_type() == at::kBool);
2177 return self.any();
2178 }
2179 Tensor logcumsumexp(const Tensor& self, Dimname dim) {
2180 return at::logcumsumexp(self, dimname_to_position(self, dim));
2181 }
2182 Tensor& logcumsumexp_out(const Tensor& self, Dimname dim, Tensor& result) {
2183 return at::logcumsumexp_out(result, self, dimname_to_position(self, dim));
2184 }
2185 Tensor cumsum(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2186 return at::cumsum(self, dimname_to_position(self, dim), dtype);
2187 }
2188 Tensor& cumsum_(Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2189 return at::cumsum_out(self, self, dimname_to_position(self, dim), dtype);
2190 }
2191 Tensor& cumsum_out(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype, Tensor& result) {
2192 return at::cumsum_out(result, self, dimname_to_position(self, dim), dtype);
2193 }
2194 Tensor cumprod(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2195 return at::cumprod(self, dimname_to_position(self, dim), dtype);
2196 }
2197 Tensor& cumprod_(Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
2198 return at::cumprod_out(self, self, dimname_to_position(self, dim), dtype);
2199 }
2200 Tensor& cumprod_out(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype, Tensor& result) {
2201 return at::cumprod_out(result, self, dimname_to_position(self, dim), dtype);
2202 }
2203 std::tuple<Tensor, Tensor> cummax(const Tensor& self, Dimname dim) {
2204 return at::cummax(self, dimname_to_position(self, dim));
2205 }
2206 std::tuple<Tensor&, Tensor&> cummax_out(const Tensor& self, Dimname dim, Tensor& values, Tensor& indices) {
2207 return at::cummax_out(values, indices, self, dimname_to_position(self, dim));
2208 }
2209 std::tuple<Tensor, Tensor> cummin(const Tensor& self, Dimname dim) {
2210 return at::cummin(self, dimname_to_position(self, dim));
2211 }
2212 std::tuple<Tensor&, Tensor&> cummin_out(const Tensor& self, Dimname dim, Tensor& values, Tensor& indices) {
2213 return at::cummin_out(values, indices, self, dimname_to_position(self, dim));
2214 }
2215
2216 Tensor dist(const Tensor &self, const Tensor& other, const Scalar& p) {
2217 return at::norm(self - other, p);
2218 }
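// Worked example: dist(self, other, p) is simply the p-norm of the
// difference, e.g. for self = {0, 0}, other = {3, 4} and p = 2 the result is
// sqrt(3^2 + 4^2) = 5.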
2219
2220 bool cpu_equal(const Tensor& self, const Tensor& other) {
2221 if (!at::namedinference::are_names_equal(
2222 self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {
2223 return false;
2224 }
2225 at::NoNamesGuard guard;
2226 TORCH_CHECK(self.device() == other.device(), "Cannot compare two tensors on "
2227 "different devices. Got: ", self.device(), " and ", other.device());
2228 if (!self.is_same_size(other)) {
2229 return false;
2230 }
2231 // Since flags like neg/conj should already be handled outside the
2232 // TensorIterator, it should be safe to take the following fast path by
2233 // ensuring that the storage and strides are exactly the same.
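// Note that even when `other` aliases `self` byte-for-byte, the answer is not
// automatically true: NaN compares unequal to itself, so a tensor containing
// a NaN is not `equal` to itself. The fast path below therefore still scans
// floating-point data for NaNs ("equal_notnan_cpu") before returning true.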
2234 if (self.is_alias_of(other)
2235 && self.storage_offset() == other.storage_offset()
2236 && self.dtype() == other.dtype()
2237 && self.is_contiguous() == other.is_contiguous()
2238 && self.strides().equals(other.strides())
2239 // Extra checks to ensure safety in case cpu_equal is called directly from C++.
2240 && self.layout() == other.layout()
2241 && self.is_neg() == other.is_neg()
2242 && self.is_conj() == other.is_conj()) {
2243 if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
2244 return true;
2245 }
2246 std::atomic<bool> result{true};
2247 auto iter = TensorIteratorConfig().add_const_input(self).build();
2248 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "equal_notnan_cpu", [&] {
2249 iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
2250 if (!result) {
2251 return;
2252 }
2253 char* self_data = data[0];
2254 for (C10_UNUSED const auto i : c10::irange(dim_size)) {
2255 if (isnan_(c10::load<scalar_t>(self_data))) {
2256 result = false;
2257 return;
2258 }
2259 self_data += strides[0];
2260 }
2261 });
2262 });
2263 return result.load();
2264 }
2265
2266 std::atomic<bool> result{true};
2267 auto iter = TensorIteratorConfig()
2268 .add_const_input(self)
2269 .add_const_input(other)
2270 .allow_cpu_scalars(true)
2271 .promote_inputs_to_common_dtype(true)
2272 .build();
2273
2274 AT_DISPATCH_V2(iter.input_dtype(), "equal_cpu", AT_WRAP([&] {
2275 iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
2276 if (!result) {
2277 return;
2278 }
2279 char* self_data = data[0];
2280 char* other_data = data[1];
2281 for (C10_UNUSED const auto i : c10::irange(dim_size)) {
2282 if (c10::load<scalar_t>(self_data) != c10::load<scalar_t>(other_data)) {
2283 result = false;
2284 return;
2285 }
2286 self_data += strides[0];
2287 other_data += strides[1];
2288 }
2289 });
2290 }), kBool, kBFloat16, kHalf, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
2291 return result.load();
2292 }
2293
2294 // max(dim), min(dim), topk(dim), and mode(dim) are examples of reduction
2295 // functions that select values. value_selecting_reduction_backward is the
2296 // backward function for those operators; it propagates the grad to the
2297 // specific value locations referred to by `indices`.
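// Rough sketch: for `values, indices = x.max(dim)`, the gradient w.r.t. x is
// a zero tensor of x's shape with `grad` scattered into the selected
// positions, i.e. approximately
//   grad_in = zeros(sizes); grad_in.scatter_(dim, indices, grad);
// which is what the helper below does (using an out-of-place scatter when
// tensor subclasses are involved).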
2298 Tensor value_selecting_reduction_backward_symint(const Tensor& grad, int64_t dim, const Tensor& indices, c10::SymIntArrayRef sizes, bool keepdim) {
2299 auto inplace_scatter_if_not_tensor_subclass =
2300 [&](const Tensor& grad_out, const Tensor& indices_) {
2301 auto grad_in = at::zeros_symint(sizes, grad_out.options());
2302 if (areAnyTensorSubclassLike({grad, indices})) {
2303 return grad_in.scatter(dim, indices_, grad_out);
2304 }
2305 return grad_in.scatter_(dim, indices_, grad_out);
2306 };
2307
2308 if (!keepdim && !sizes.empty()) {
2309 auto grad_ = grad.unsqueeze(dim);
2310 auto indices_ = indices.unsqueeze(dim);
2311 return inplace_scatter_if_not_tensor_subclass(grad_, indices_);
2312 }
2313 return inplace_scatter_if_not_tensor_subclass(grad, indices);
2314 }
2315
2316 Tensor sum_csr(const Tensor &self, std::optional<ScalarType> dtype) {
2317 return self.values().sum(dtype);
2318 }
2319
2320 Tensor sum_coo(const Tensor &self, std::optional<ScalarType> dtype) {
2321 return self._values().sum(dtype);
2322 }
2323
2324 Tensor sum_sparse_coo(const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, std::optional<ScalarType> dtype) {
2325 Tensor result;
2326 if (dim.has_value()) {
2327 if (dtype.has_value()) {
2328 result = at::_sparse_sum(self, *dim, *dtype);
2329 } else {
2330 if (c10::isIntegralType(self.scalar_type(), true)) {
2331 result = at::_sparse_sum(self, *dim, at::kLong);
2332 } else {
2333 result = at::_sparse_sum(self, *dim);
2334 }
2335 }
2336 } else {
2337 result = sum_coo(self, dtype);
2338 }
2339 if (keepdim) {
2340 auto dim_mask = make_dim_mask(dim, self.dim());
2341 for (int dim = 0; dim < self.dim(); dim++) {
2342 if (dim_mask[dim]) {
2343 result = result.unsqueeze(dim);
2344 }
2345 }
2346 }
2347 return result;
2348 }
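// Note: _sparse_sum drops the reduced dims, so when keepdim is true the loop
// above re-inserts a size-1 dim at each reduced position, e.g. summing a
// sparse (3, 4) input over dim 0 with keepdim=true yields shape (1, 4).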
2349
2350 Tensor sum_sparse_compressed(
2351 const Tensor& self,
2352 at::OptionalIntArrayRef dim,
2353 bool keepdim,
2354 std::optional<ScalarType> dtype) {
2355 // TODO: The signatures of sum.dim_IntList and _sparse_csr_sum.dim_dtype differ
2356 // slightly in the second parameter `dim`, which forces a conversion of `dim`
2357 // before calling into `_sparse_csr_sum`. Aligning the signatures would be a better choice.
2358 TORCH_CHECK(
2359 dim.has_value(), "dim has no value, cannot be used in sum.dim_IntList");
2360 auto layout = self.layout();
2361 TORCH_CHECK(
2362 layout == kSparseCsr,
2363 "Currently the only compressed sparse format supported for sum.dim_IntList is CSR, but got layout ",
2364 layout)
2365 return at::_sparse_csr_sum(self, *dim, keepdim, dtype);
2366 }
2367
2368 } // namespace at::native
2369