#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/SortingUtils.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/argsort_native.h>
#include <ATen/ops/broadcast_tensors.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/full.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/kthvalue.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/masked_fill.h>
#include <ATen/ops/median.h>
#include <ATen/ops/median_native.h>
#include <ATen/ops/msort_native.h>
#include <ATen/ops/nanmedian.h>
#include <ATen/ops/nanmedian_native.h>
#include <ATen/ops/nanquantile_native.h>
#include <ATen/ops/quantile_native.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sort.h>
#include <ATen/ops/sort_native.h>
#include <ATen/ops/topk_native.h>
#endif

#include <utility>

namespace at::meta {

using namespace ::at::native;

TORCH_META_FUNC(topk)
(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  TORCH_CHECK(
      k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
      "selected index k out of range");
  int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
  TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");

  // Build the output size, which is the dim being selected set to
  // size k
  DimVector topKSize(self.sizes().vec());
  if (!topKSize.empty()) {
    topKSize[dim] = k;
  }
  set_output_raw_strided(0, topKSize, {}, self.options());
  set_output_raw_strided(1, topKSize, {}, self.options().dtype(at::kLong));
}
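
// Illustrative usage sketch (comment only, not part of this TU's tests): the
// meta function above only shapes the outputs. Both outputs share the input's
// sizes except that `dim` is set to k, and indices are always int64.
//
//   at::Tensor t = at::rand({3, 5});
//   auto [values, indices] = at::topk(t, /*k=*/2, /*dim=*/1,
//                                     /*largest=*/true, /*sorted=*/true);
//   // values:  shape {3, 2}, same dtype as t
//   // indices: shape {3, 2}, dtype kLong, indexing into dim 1 of t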

TORCH_META_FUNC2(sort, stable)
(const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
  maybe_wrap_dim(dim, self.dim());

  // See issue: https://github.com/pytorch/pytorch/issues/65863
  // Strides should be dense, so as not to allocate too much memory.
  // We either use 'self' strides, or infer dense strides from them.
  std::vector<int64_t> strides = (self.is_non_overlapping_and_dense())
      ? self.strides().vec()
      : at::infer_dense_strides(self.sizes(), self.strides());

  set_output_raw_strided(0, self.sizes(), strides, self.options(), {});
  set_output_raw_strided(1, self.sizes(), strides, self.options().dtype(kLong), {});
}

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(sort_stub);
DEFINE_DISPATCH(topk_stub);

void _fill_indices(const TensorBase& indices, int64_t dim) {
  auto ndim = indices.dim();
  assert(0 <= dim && dim < ndim);
  auto dim_size = indices.size(dim);
  auto idx_dim = at::arange(0, dim_size, indices.options().dtype(at::kLong));
  auto idx_dim_sizes = std::vector<int64_t>(ndim, 1);
  auto idx_dim_strides = std::vector<int64_t>(ndim, 0);
  idx_dim_sizes[dim] = dim_size;
  idx_dim_strides[dim] = 1;
  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  OptionalTensorRef(indices)->copy_(idx_dim_restrided);
}
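
// Sketch of what _fill_indices produces (illustrative comment; the tensor
// below is an assumption, not code from this file): it writes the running
// index along `dim` into every slice of `indices` by broadcasting an arange
// through a zero-stride view.
//
//   at::Tensor idx = at::empty({2, 3}, at::kLong);
//   _fill_indices(idx, /*dim=*/1);
//   // idx is now [[0, 1, 2],
//   //             [0, 1, 2]]; with dim == 0 each column would be {0, 1}.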

namespace {

/* Note from TH:
I cut and pasted (slightly adapted) the quicksort code from
Sedgewick's 1978 "Implementing Quicksort Programs" article
http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf

It is the state-of-the-art existing implementation. The macros
are here to make as close a match as possible to the pseudocode of
Program 2 p.851

Note that other partition schemes exist, and are typically presented
in textbooks, but those are less efficient. See e.g.
http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto

Julien, November 12th 2013
*/
template <typename scalar_t, typename Comp, typename Fn>
void quick_select_template(
    TensorAccessor<scalar_t, 1> arr,
    int64_t k,
    Comp gt_or_nan,
    Fn swap_fn) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t P, L, R, i, j;
  scalar_t piv;
  L = 0;
  R = arr.size(0) - 1;

  do {
    if (R <= L) // One element only
      return;

    if (R == L + 1) { // Two elements only
      if (gt_or_nan(arr[L], arr[R])) {
        swap_fn(L, R);
      }
      return;
    }

    // Use median of three for pivot choice
    P = L + (R - L) / 2;
    swap_fn(P, L + 1);
    if (gt_or_nan(arr[L + 1], arr[R])) {
      swap_fn(L + 1, R);
    }
    if (gt_or_nan(arr[L], arr[R])) {
      swap_fn(L, R);
    }
    if (gt_or_nan(arr[L + 1], arr[L])) {
      swap_fn(L + 1, L);
    }

    i = L + 1;
    j = R;
    piv = arr[L];
    do {
      do
        i++;
      while (gt_or_nan(piv, arr[i]));
      do
        j--;
      while (gt_or_nan(arr[j], piv));
      if (j < i)
        break;
      swap_fn(i, j);
    } while (true);
    swap_fn(L, j);

    // Re-set active partition
    if (j <= k)
      L = i;
    if (j >= k)
      R = j - 1;
  } while (true);
}
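
// Minimal usage sketch for quick_select_template (illustrative; `accessor` and
// `buf` below are assumptions, not code from this file). After the call,
// arr[k] holds the element that would sit at 0-based position k of the fully
// sorted slice, e.g. for values {3, 1, 4, 1, 5} and k == 2 the selected
// element is 3 (the median); the rest of the slice is only partitioned, not
// sorted.
//
//   quick_select_template(
//       accessor, /*k=*/2,
//       [](float x, float y) { return x > y; },                     // gt_or_nan
//       [&](int64_t i, int64_t j) { std::swap(buf[i], buf[j]); });  // swap_fn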

namespace {

QUANTILE_INTERPOLATION_MODE get_quantile_interpolation_mode(
    const c10::string_view interpolation) {
  if (interpolation == "linear") {
    return QUANTILE_INTERPOLATION_MODE::LINEAR;
  } else if (interpolation == "lower") {
    return QUANTILE_INTERPOLATION_MODE::LOWER;
  } else if (interpolation == "higher") {
    return QUANTILE_INTERPOLATION_MODE::HIGHER;
  } else if (interpolation == "midpoint") {
    return QUANTILE_INTERPOLATION_MODE::MIDPOINT;
  } else if (interpolation == "nearest") {
    return QUANTILE_INTERPOLATION_MODE::NEAREST;
  } else {
    TORCH_CHECK(
        false,
        "quantile() interpolation must be one of linear, lower, higher, midpoint or nearest, but got ",
        interpolation);
  }
}
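
// Worked example of the five modes (illustrative comment only): for sorted
// values {10, 20, 30, 40, 50} and q = 0.4, the fractional rank is
// q * (n - 1) = 1.6, so
//   lower    -> values[1] = 20
//   higher   -> values[2] = 30
//   nearest  -> values[2] = 30   (1.6 rounds to 2)
//   midpoint -> (20 + 30) / 2 = 25
//   linear   -> 20 + 0.6 * (30 - 20) = 26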

void quantile_checks(const Tensor& self, const Tensor& q) {
  TORCH_CHECK(self.numel() > 0, "quantile() input tensor must be non-empty");
  TORCH_CHECK(q.dim() <= 1, "quantile() q must be a scalar or 1D tensor");
  TORCH_CHECK(
      self.scalar_type() == kFloat || self.scalar_type() == kDouble,
      "quantile() input tensor must be either float or double dtype");
  TORCH_CHECK(
      self.scalar_type() == q.scalar_type(),
      "quantile() q tensor must be same dtype as the input tensor");
  TORCH_CHECK(
      self.device() == q.device(),
      "quantile() q tensor must be on the same device as the input tensor");
}

std::vector<int64_t> quantile_output_shape(
    const std::optional<int64_t> original_dim,
    const Tensor& self,
    const Tensor& q,
    const bool keepdim,
    int64_t wrapped_dim) {
  // Compute output shape: q_size + reduced_size
  std::vector<int64_t> out_shape;
  if (original_dim && self.dim() > 0) {
    out_shape = self.sizes().vec();
    if (keepdim) {
      out_shape[wrapped_dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + wrapped_dim);
    }
  } else if (keepdim) {
    out_shape = std::vector<int64_t>(self.dim(), 1);
  }
  if (q.dim() > 0) {
    out_shape.insert(out_shape.begin(), q.numel());
  }

  return out_shape;
}
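
// Shape sketch (illustrative): for self of shape {4, 6}, a 1-D q with 3
// entries, dim = 1 and keepdim = false, the reduced shape is {4} and the q
// dimension is prepended, giving an output shape of {3, 4}. With keepdim =
// true it would be {3, 4, 1}; with a scalar q nothing is prepended.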

Tensor quantile_compute(
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan,
    int64_t wrapped_dim,
    std::vector<int64_t> out_shape) {
  // Checks that all q values are between 0 and 1, inclusive
  // NOTE: this check is only performed when running on the CPU to avoid
  // synchronizing an accelerator with the CPU
  if (self.device().is_cpu()) {
    auto all_q_in_range = q.ge(0).logical_and_(q.le(1)).all();
    TORCH_CHECK(at::is_scalar_tensor_true(all_q_in_range),
                "quantile() q values must be in the range [0, 1]");
  }

  // Flatten input if no dim provided else move dim to reduce as last dimension.
  // Sort to efficiently query kth values.
  Tensor sorted;
  if (!original_dim) {
    sorted = std::get<0>(self.flatten().sort());
  } else if (wrapped_dim == self.dim() - 1) {
    sorted = std::get<0>(self.sort());
  } else {
    sorted = std::get<0>(self.unsqueeze(-1).transpose(wrapped_dim, -1).sort());
  }

  // Treat q as a 1D tensor for the following computations
  if (q.dim() == 0) {
    out_shape.insert(out_shape.begin(), q.numel());
  }

  // View input as reduced_size + size of dim to reduce
  std::vector<int64_t> in_shape(out_shape.size());
  std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin());
  in_shape[in_shape.size() - 1] = sorted.size(-1);
  sorted = sorted.view(in_shape);

  // Ensure converting from int64_t to double won't overflow
  TORCH_CHECK(
      sorted.size(-1) <= std::pow(2, 24),
      "quantile() input tensor is too large");

  // Convert q in [0, 1] to ranks in [0, reduction_size)
  Tensor ranks;
  if (ignore_nan) {
    // For nanquantile, compute ranks based on number of non-nan values.
    // If all values are nan, set rank to 0 so the quantile computed is nan.
    ranks = q * (sorted.isnan().logical_not_().sum(-1, true) - 1);
    // For Composite Compliance,
    // if `ranks` is a `CCT` but its tangent is a regular Tensor,
    // then while computing jvp, we end up calling `masked_fill_`
    // on a regular Tensor with CCT args, so we call
    // `masked_fill` instead.
    if (isTensorSubclassLike(ranks) && ranks._fw_grad(/*level=*/0).defined()) {
      ranks = ranks.masked_fill(ranks < 0, 0);
    } else {
      ranks.masked_fill_(ranks < 0, 0);
    }
  } else {
    // For quantile, compute ranks based on reduction size. If there is nan
    // set rank to last index so the quantile computed will be nan.
    int64_t last_index = sorted.size(-1) - 1;
    std::vector<Tensor> tl =
        at::broadcast_tensors({q * last_index, sorted.isnan().any(-1, true)});
    ranks = at::masked_fill(tl[0], tl[1], last_index);
  }

  // Adjust ranks based on the interpolation mode
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LOWER) {
    ranks.floor_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::HIGHER) {
    ranks.ceil_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::NEAREST) {
    ranks.round_();
  }
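
  // Comment-only illustration of the NaN handling above: the sort places NaNs
  // at the end of each slice, so when ignore_nan is false and a slice contains
  // a NaN, forcing its rank to last_index makes the gather below pick that NaN
  // and the reported quantile is NaN. When ignore_nan is true, ranks are
  // computed against the count of non-NaN values instead, e.g. a slice
  // {1., NaN, 3.} with q = 0.5 uses rank 0.5 * (2 - 1) = 0.5 over {1., 3.},
  // which interpolates to 2. in linear mode.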

  Tensor ranks_below = ranks.toType(kLong);
  Tensor values_below = sorted.gather(-1, ranks_below);

  // Actual interpolation is only needed for the linear and midpoint modes
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LINEAR ||
      interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT) {
    // Calculate weights for linear and midpoint
    Tensor weights = interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT
        ? at::full_like(ranks, 0.5)
        : ranks - ranks_below;

    // Interpolate to compute quantiles and store in values_below
    Tensor ranks_above = ranks.ceil_().toType(kLong);
    Tensor values_above = sorted.gather(-1, ranks_above);
    // For Composite Compliance,
    // if either `values_below`, `values_above` or `weights` is a CCT,
    // or the tangents of `values_above` and `weights` are a CCT,
    // but the tangent of `values_below` is a regular Tensor,
    // then while computing jvp, we would end up copying a `CCT`
    // into a regular Tensor. So we use the out-of-place variant of `lerp`.
    auto is_primal_cct =
        areAnyTensorSubclassLike({values_below, values_above, weights});
    auto is_tangent_cct = areAnyTensorSubclassLike(
        {values_above._fw_grad(/*level=*/0), weights._fw_grad(/*level=*/0)});
    if ((is_primal_cct || is_tangent_cct) &&
        values_below._fw_grad(/*level=*/0).defined() &&
        !isTensorSubclassLike(values_below._fw_grad(/*level=*/0))) {
      values_below = values_below.lerp(values_above, weights);
    } else {
      values_below.lerp_(values_above, weights);
    }
  }

  if (q.dim() == 0) {
    // If q is scalar, remove last dim to match out shape
    values_below.squeeze_(-1);
  } else {
    // Move quantiles to first dim to match out shape
    values_below.unsqueeze_(0).transpose_(0, -1).squeeze_(-1);
  }

  return values_below;
}

} // namespace

void quantile_out_impl(
    Tensor& out,
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan) {
  quantile_checks(self, q);
  TORCH_CHECK(
      self.scalar_type() == out.scalar_type(),
      "quantile() out tensor must be same dtype as the input tensor");
  TORCH_CHECK(
      self.device() == out.device(),
      "quantile() out tensor must be on the same device as the input tensor");

  int64_t wrapped_dim = at::maybe_wrap_dim(original_dim.value_or(0), self.dim());

  auto out_shape = quantile_output_shape(original_dim, self, q, keepdim, wrapped_dim);
  resize_output(out, out_shape);

  auto quantile = quantile_compute(
      self, q, original_dim, keepdim, interpolation, ignore_nan, wrapped_dim, std::move(out_shape));
  out.copy_(quantile);
}

Tensor quantile_impl(
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan) {
  quantile_checks(self, q);

  int64_t wrapped_dim = at::maybe_wrap_dim(original_dim.value_or(0), self.dim());

  auto out_shape = quantile_output_shape(original_dim, self, q, keepdim, wrapped_dim);

  return quantile_compute(
      self, q, original_dim, keepdim, interpolation, ignore_nan, wrapped_dim, std::move(out_shape));
}
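
// Illustrative call sites for the entry points built on the helpers above
// (comment only; the values are assumptions, not tests from this file):
//
//   at::Tensor t = at::arange(10, at::kFloat);   // 0, 1, ..., 9
//   at::Tensor m = at::quantile(t, /*q=*/0.5);   // 4.5 with "linear"
//   at::Tensor p = at::quantile(t, at::scalar_tensor(0.25, t.options()),
//                               /*dim=*/0, /*keepdim=*/false, "lower");
//   // nanquantile takes the same arguments but drops NaNs from each slice.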

std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cpu(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
  zero_numel_check_dims(self, dim, "kthvalue()");

  TORCH_CHECK(k >= 1 && k <= slicesize,
              "kthvalue(): selected number k out of range for dimension ", dim);

  at::assert_no_overlap(self, values);

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, dim_, keepdim);
  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
    return std::forward_as_tuple(values, indices);
  }
  auto tmp_values = self.clone(at::MemoryFormat::Contiguous);
  auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong));

  auto tmp_values_stride = tmp_values.strides()[dim];
  auto tmp_indices_stride = tmp_indices.strides()[dim];
  auto sizes = self.sizes();

  TORCH_CHECK(indices.scalar_type() == kLong);

  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(sizes, /*squash_dims=*/dim)
    .add_output(tmp_values)
    .add_output(tmp_indices)
    .add_output(values)
    .add_output(indices)
    .build();

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "kthvalue_cpu", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      for (const auto i : c10::irange(n)) {
        TensorAccessor<scalar_t, 1> tmp_values(
            reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
            &sizes[dim], &tmp_values_stride);
        TensorAccessor<int64_t, 1> tmp_indices(
            reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
            &sizes[dim], &tmp_indices_stride);
        auto mode_value = reinterpret_cast<scalar_t*>(data[2] + i * strides[2]);
        auto mode_index = reinterpret_cast<int64_t*>(data[3] + i * strides[3]);

        for (const auto j : c10::irange(tmp_indices.size(0))) {
          tmp_indices[j] = j;
        }

        // we want NaN to be sorted as top for numpy compatibility
        quick_select_template(
            tmp_values,
            k - 1,
            [](scalar_t x, scalar_t y) -> bool {
              return (
                  (_isnan<scalar_t>(x) && !_isnan<scalar_t>(y)) || (x > y));
            },
            [&](int64_t i, int64_t j) {
              std::swap(tmp_values[i], tmp_values[j]);
              std::swap(tmp_indices[i], tmp_indices[j]);
            });
        *mode_value = tmp_values[k - 1];
        *mode_index = tmp_indices[k - 1];
      }
    };

    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });

  if (!keepdim) {
    values.squeeze_(dim);
    indices.squeeze_(dim);
  }
  return std::forward_as_tuple(values, indices);
}
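
// Usage sketch for the kernel above (illustrative comment; values assumed):
//
//   at::Tensor t = at::rand({2, 4});
//   auto [val, idx] = at::kthvalue(t, /*k=*/2, /*dim=*/1, /*keepdim=*/false);
//   // val[i] is the 2nd smallest element of row i and idx[i] its column;
//   // k is 1-based at the API level, while quick_select_template above
//   // receives the 0-based position k - 1.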

// Computes both the median and its index along dimension dim of the input
std::tuple<Tensor&, Tensor&> median_with_indices_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    bool ignore_nan) {
  dim = at::maybe_wrap_dim(dim, self.dim());

  int64_t size = self.dim() > 0 ? self.size(dim) : 1;
  zero_numel_check_dims(self, dim, "median()");

  checkDeviceType("median", {values, indices}, self.device().type());
  checkScalarType("median", {indices, "indices", 1}, kLong);
  checkSameType("median", {values, "values", 0}, {self, "self", 2});

  std::vector<int64_t> out_shape = self.sizes().vec();
  if (self.dim() > 0) {
    if (keepdim) {
      out_shape[dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + dim);
    }
  }

  resize_output(values, out_shape);
  resize_output(indices, out_shape);

  // Ensure #dim is the same for all tensors required for dim_apply
  Tensor in = self.dim() > 0 ? self : self.unsqueeze(0);
  Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim);
  Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim);

  // Make dim to reduce contiguous (stride=1)
  if (in.stride(dim) > 1) {
    in = in.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim).contiguous();
    vals = vals.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim);
    inds = inds.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim);
    dim = in.dim() - 1;
  }

  auto sizes = in.sizes();
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(sizes, /*squash_dims=*/dim)
    .add_output(vals)
    .add_output(inds)
    .add_const_input(in)
    .build();

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, in.scalar_type(), "median_out", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      for (const auto i : c10::irange(n)) {
        auto valp = reinterpret_cast<scalar_t*>(data[0] + i * strides[0]);
        auto indp = reinterpret_cast<int64_t*>(data[1] + i * strides[1]);
        auto ip = reinterpret_cast<const scalar_t*>(data[2] + i * strides[2]);

        // For torch.median, search for NaN and return it if found
        if (!ignore_nan) {
          const scalar_t* nanp = std::find_if(ip, ip + size, _isnan<scalar_t>);
          if (nanp != ip + size) {
            *valp = *nanp;
            *indp = nanp - ip;
            continue;
          }
        }

        // Vector of indices for indirectly partitioning input around median
        std::vector<int64_t> idx(size);
        auto first = idx.begin();
        auto last = idx.end();
        std::iota(first, last, 0);

        // We partition the input around the median indirectly using the indices
        // vector so that nth points to the index of the median in the unmodified
        // input tensor.
        auto nth = first;
        if (!ignore_nan) {
          // If we got here, there are no nan values
          nth += (size - 1) / 2;
          std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) {
            return ip[i] < ip[j] || (ip[i] == ip[j] && i < j);
          });
        } else {
          // For torch.nanmedian, compute median of non-nan values only
          int64_t num_nan = std::count_if(ip, ip + size, _isnan<scalar_t>);
          nth += (size - num_nan - 1) / 2;
          std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) {
            return ip[i] < ip[j] || (ip[i] == ip[j] && i < j) ||
                (_isnan(ip[j]) && !_isnan(ip[i]));
          });
        }

        *valp = ip[*nth];
        *indp = *nth;
      }
    };
    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });

  return std::forward_as_tuple(values, indices);
}
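
// Behaviour sketch (illustrative comment): for each slice along `dim`, the
// kernel above reports the lower of the two middle elements (index
// (size - 1) / 2 of the sorted slice) together with its position in the
// original, unsorted input. With ignore_nan == false a NaN anywhere in the
// slice short-circuits to that NaN; with ignore_nan == true the median is
// taken over the non-NaN values only, e.g. {2., NaN, 1., 3.} -> 2. at index 0.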

// Computes the median of all values in the input
Tensor median_impl(const Tensor& self, bool ignore_nan) {
  NoNamesGuard guard;
  const int64_t size = self.numel();

  // Return nan for empty tensors
  if (size <= 0) {
    return at::full({}, std::numeric_limits<float>::quiet_NaN()).to(self.options());
  }

  // Clone the input tensor so we can partition it around the median value
  Tensor in = self.clone();
  Tensor out = at::empty({}, self.options());

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, in.scalar_type(), "median_cpu", [&] {
    scalar_t* op = out.data_ptr<scalar_t>();
    scalar_t* first = in.data_ptr<scalar_t>();
    scalar_t* last = first + size;

    // For torch.median, if there are nan values return nan
    if (!ignore_nan && std::any_of(first, last, _isnan<scalar_t>)) {
      *op = std::numeric_limits<scalar_t>::quiet_NaN();
      return;
    }

    scalar_t* median = first;
    if (!ignore_nan) {
      // If we got here, there are no nan values
      median += (size - 1) / 2;
      std::nth_element(first, median, last);
    } else {
      // For torch.nanmedian, compute median of non-nan values only
      int64_t num_nan = std::count_if(first, last, _isnan<scalar_t>);
      median += (size - num_nan - 1) / 2;
      std::nth_element(first, median, last, [](scalar_t a, scalar_t b) {
        return a < b || (_isnan(b) && !_isnan(a));
      });
    }

    *op = *median;
  });

  return out;
}
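
// Illustrative call sites (comment only; values are assumptions):
//
//   at::Tensor t = at::arange(4, at::kFloat);   // 0, 1, 2, 3
//   at::Tensor m = at::median(t);               // 1. (the lower middle element)
//   // at::nanmedian(t) matches at::median(t) when t has no NaNs; if it does,
//   // median() returns NaN while nanmedian() ignores the NaN entries.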

} // namespace

Tensor& quantile_out(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  quantile_out_impl(
      out,
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/false);
  return out;
}

Tensor& quantile_out(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::quantile_out(
      self,
      at::scalar_tensor(q, self.options()),
      dim,
      keepdim,
      interpolation,
      out);
}

Tensor quantile(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  return quantile_impl(
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/false);
}

Tensor quantile(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::quantile(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
}

Tensor& nanquantile_out(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  quantile_out_impl(
      out,
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/true);
  return out;
}

Tensor& nanquantile_out(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::nanquantile_out(
      self,
      at::scalar_tensor(q, self.options()),
      dim,
      keepdim,
      interpolation,
      out);
}

Tensor nanquantile(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  return quantile_impl(
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/true);
}

Tensor nanquantile(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::nanquantile(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
}

std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return kthvalue_out_impl_cpu(values, indices, self, k, dim, keepdim);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> kthvalue_out(
    const Tensor& self,
    int64_t k,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::kthvalue_out(
      values, indices, self, k, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> kthvalue(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::kthvalue_out(values, indices, self, k, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> kthvalue(
    const Tensor& self,
    int64_t k,
    Dimname dim,
    bool keepdim) {
  return at::kthvalue(self, k, dimname_to_position(self, dim), keepdim);
}

TORCH_IMPL_FUNC(topk_out_cpu)
(const Tensor& self,
 int64_t k,
 int64_t dim_,
 bool largest,
 bool sorted,
 const Tensor& values,
 const Tensor& indices) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  TORCH_CHECK(
      k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
      "selected index k out of range");

  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
  } else {
    topk_stub(kCPU, values, indices, self, k, dim, largest, sorted);
  }
}

std::tuple<Tensor&, Tensor&> median_out_cpu(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return median_with_indices_impl(
        values, indices, self, dim, keepdim, /*ignore_nan=*/false);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> median_out(
    const Tensor& self,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::median_out(
      values, indices, self, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> median(
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::median_out(values, indices, self, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> median(
    const Tensor& self,
    Dimname dim,
    bool keepdim) {
  return at::median(self, dimname_to_position(self, dim), keepdim);
}

Tensor median_cpu(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/false);
}

std::tuple<Tensor&, Tensor&> nanmedian_out_cpu(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return median_with_indices_impl(
        values, indices, self, dim, keepdim, /*ignore_nan=*/true);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> nanmedian_out(
    const Tensor& self,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::nanmedian_out(
      values, indices, self, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> nanmedian(
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::nanmedian_out(values, indices, self, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> nanmedian(
    const Tensor& self,
    Dimname dim,
    bool keepdim) {
  return at::nanmedian(self, dimname_to_position(self, dim), keepdim);
}

Tensor nanmedian_cpu(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/true);
}

TORCH_IMPL_FUNC(sort_stable_out)
(const Tensor& self,
 std::optional<bool> stable,
 int64_t dim,
 bool descending,
 const Tensor& values,
 const Tensor& indices) {
  values.copy_(self);
  // check if self is scalar
  if (self.dim() == 0 && self.numel() == 1) {
    indices.zero_();
  } else {
    dim = maybe_wrap_dim(dim, self.dim());
    sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value_or(false));
  }
}
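
// Illustrative call sites (comment only): the structured kernel above backs
// both the plain and the stable overloads of sort.
//
//   at::Tensor t = at::rand({4, 3});
//   auto [v1, i1] = at::sort(t, /*dim=*/0, /*descending=*/false);
//   auto [v2, i2] = at::sort(t, /*stable=*/true, /*dim=*/0, /*descending=*/false);
//   // With stable=true, equal elements keep their original relative order,
//   // which pins down the returned indices when there are ties.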

std::tuple<Tensor&, Tensor&> sort_out(
    const Tensor& self,
    int64_t dim,
    bool descending,
    Tensor& values,
    Tensor& indices) {
  return at::sort_out(values, indices, self, false, dim, descending);
}

std::tuple<Tensor, Tensor> sort(
    const Tensor& self,
    int64_t dim,
    bool descending) {
  return at::sort(self, false, dim, descending);
}

Tensor& msort_out(const Tensor& self, Tensor& values) {
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::sort_out(values, indices, self, 0, false);
  return values;
}

Tensor msort(const Tensor& self) {
  return std::get<0>(at::sort(self, 0, false));
}

Tensor argsort(const Tensor& self, int64_t dim, bool descending) {
  return std::get<1>(at::sort(self, dim, descending));
}

Tensor argsort(const Tensor& self, bool stable, int64_t dim, bool descending) {
  return std::get<1>(at::sort(self, stable, dim, descending));
}
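
// Illustrative comment for the wrappers above: msort(t) sorts along dim 0 and
// keeps only the values, while argsort keeps only the indices, e.g.
//
//   at::Tensor t = at::rand({5});
//   at::Tensor order = at::argsort(t, /*dim=*/0, /*descending=*/false);
//   // t.index_select(0, order) is t in ascending order.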

Tensor& argsort_out(const Tensor& self, bool stable, int64_t dim, bool descending, Tensor& out) {
  auto values = at::empty({0}, self.options());
  at::sort_outf(self, stable, dim, descending, values, out);
  return out;
}


} // namespace at::native