// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Sorting.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/SortingUtils.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/argsort_native.h>
#include <ATen/ops/broadcast_tensors.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/full.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/kthvalue.h>
#include <ATen/ops/kthvalue_native.h>
#include <ATen/ops/masked_fill.h>
#include <ATen/ops/median.h>
#include <ATen/ops/median_native.h>
#include <ATen/ops/msort_native.h>
#include <ATen/ops/nanmedian.h>
#include <ATen/ops/nanmedian_native.h>
#include <ATen/ops/nanquantile_native.h>
#include <ATen/ops/quantile_native.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sort.h>
#include <ATen/ops/sort_native.h>
#include <ATen/ops/topk_native.h>
#endif

#include <utility>

namespace at::meta {

using namespace ::at::native;

TORCH_META_FUNC(topk)
(const Tensor& self, int64_t k, int64_t dim_, bool largest, bool sorted) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  TORCH_CHECK(
      k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
      "selected index k out of range");
  int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
  TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");

  // Build the output size, which is the dim being selected set to
  // size k
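  // For example (hypothetical shapes): with self.sizes() == {3, 5}, dim == 1
  // and k == 2, both outputs are allocated as {3, 2} (values, and long indices).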
  DimVector topKSize(self.sizes().vec());
  if (!topKSize.empty()) {
    topKSize[dim] = k;
  }
  set_output_raw_strided(0, topKSize, {}, self.options());
  set_output_raw_strided(1, topKSize, {}, self.options().dtype(at::kLong));
}

TORCH_META_FUNC2(sort, stable)
(const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
  maybe_wrap_dim(dim, self.dim());

  // See issue: https://github.com/pytorch/pytorch/issues/65863
  // Strides should be dense, so as not to allocate too much memory.
  // We either use 'self' strides, or infer dense strides from them.
  std::vector<int64_t> strides = (self.is_non_overlapping_and_dense())
      ? self.strides().vec()
      : at::infer_dense_strides(self.sizes(), self.strides());
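  // For example (hypothetical values): sizes {3, 4} with non-dense strides
  // {1, 6} would yield dense strides {1, 3}, preserving the relative memory
  // order of the dimensions while avoiding over-allocation.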

  set_output_raw_strided(0, self.sizes(), strides, self.options(), {});
  set_output_raw_strided(1, self.sizes(), strides, self.options().dtype(kLong), {});
}

} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(sort_stub);
DEFINE_DISPATCH(topk_stub);

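// Fills `indices` so that, along dimension `dim`, every slice contains the
// sequence 0, 1, ..., size(dim) - 1. For example (hypothetical shape), a
// 2x3 `indices` tensor with dim == 1 becomes {{0, 1, 2}, {0, 1, 2}}.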
void _fill_indices(const TensorBase &indices, int64_t dim) {
  auto ndim = indices.dim();
  assert(0 <= dim && dim < ndim);
  auto dim_size = indices.size(dim);
  auto idx_dim = at::arange(0, dim_size, indices.options().dtype(at::kLong));
  auto idx_dim_sizes = std::vector<int64_t>(ndim, 1);
  auto idx_dim_strides = std::vector<int64_t>(ndim, 0);
  idx_dim_sizes[dim] = dim_size;
  idx_dim_strides[dim] = 1;
  auto idx_dim_restrided = idx_dim.as_strided(idx_dim_sizes, idx_dim_strides);
  OptionalTensorRef(indices)->copy_(idx_dim_restrided);
}

namespace {

/* Note from TH:
   I cut and pasted (slightly adapted) the quicksort code from
   Sedgewick's 1978 "Implementing Quicksort Programs" article
   http://www.csie.ntu.edu.tw/~b93076/p847-sedgewick.pdf

   It is the state-of-the-art existing implementation. The macros
   are here to make as close a match as possible to the pseudocode of
   Program 2 p.851

   Note that other partition schemes exist, and are typically presented
   in textbooks, but those are less efficient. See e.g.
   http://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto

   Julien, November 12th 2013
*/
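// quick_select_template partially sorts `arr` in place so that, on return,
// arr[k] holds the element that would land at index k in a full ascending
// sort (NaN ordering is delegated to the comparator). For example
// (hypothetical values): arr == {5, 1, 4, 2, 3} and k == 2 leaves
// arr[2] == 3, the third-smallest element.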
template <typename scalar_t, typename Comp, typename Fn>
void quick_select_template(
    TensorAccessor<scalar_t, 1> arr,
    int64_t k,
    Comp gt_or_nan,
    Fn swap_fn) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t P, L, R, i, j;
  scalar_t piv;
  L = 0;
  R = arr.size(0) - 1;

  do {
    if (R <= L) // One element only
      return;

    if (R == L + 1) { // Two elements only
      if (gt_or_nan(arr[L], arr[R])) {
        swap_fn(L, R);
      }
      return;
    }

    // Use median of three for pivot choice
    P = L + (R - L) / 2;
    swap_fn(P, L + 1);
    if (gt_or_nan(arr[L + 1], arr[R])) {
      swap_fn(L + 1, R);
    }
    if (gt_or_nan(arr[L], arr[R])) {
      swap_fn(L, R);
    }
    if (gt_or_nan(arr[L + 1], arr[L])) {
      swap_fn(L + 1, L);
    }

    i = L + 1;
    j = R;
    piv = arr[L];
    do {
      do
        i++;
      while (gt_or_nan(piv, arr[i]));
      do
        j--;
      while (gt_or_nan(arr[j], piv));
      if (j < i)
        break;
      swap_fn(i, j);
    } while (true);
    swap_fn(L, j);

    // Re-set active partition
    if (j <= k)
      L = i;
    if (j >= k)
      R = j - 1;
  } while (true);
}

namespace {

QUANTILE_INTERPOLATION_MODE get_quantile_interpolation_mode(
    const c10::string_view interpolation) {
  if (interpolation == "linear") {
    return QUANTILE_INTERPOLATION_MODE::LINEAR;
  } else if (interpolation == "lower") {
    return QUANTILE_INTERPOLATION_MODE::LOWER;
  } else if (interpolation == "higher") {
    return QUANTILE_INTERPOLATION_MODE::HIGHER;
  } else if (interpolation == "midpoint") {
    return QUANTILE_INTERPOLATION_MODE::MIDPOINT;
  } else if (interpolation == "nearest") {
    return QUANTILE_INTERPOLATION_MODE::NEAREST;
  } else {
    TORCH_CHECK(
        false,
        "quantile() interpolation must be one of linear, lower, higher, midpoint or nearest, but got ",
        interpolation);
  }
}

void quantile_checks(const Tensor& self, const Tensor& q) {
  TORCH_CHECK(self.numel() > 0, "quantile() input tensor must be non-empty");
  TORCH_CHECK(q.dim() <= 1, "quantile() q must be a scalar or 1D tensor");
  TORCH_CHECK(
      self.scalar_type() == kFloat || self.scalar_type() == kDouble,
      "quantile() input tensor must be either float or double dtype");
  TORCH_CHECK(
      self.scalar_type() == q.scalar_type(),
      "quantile() q tensor must be same dtype as the input tensor");
  TORCH_CHECK(
      self.device() == q.device(),
      "quantile() q tensor must be on the same device as the input tensor");
}

std::vector<int64_t> quantile_output_shape(
    const std::optional<int64_t> original_dim,
    const Tensor& self,
    const Tensor& q,
    const bool keepdim,
    int64_t wrapped_dim) {
  // Compute output shape: q_size + reduced_size
  std::vector<int64_t> out_shape;
  if (original_dim && self.dim() > 0) {
    out_shape = self.sizes().vec();
    if (keepdim) {
      out_shape[wrapped_dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + wrapped_dim);
    }
  } else if (keepdim) {
    out_shape = std::vector<int64_t>(self.dim(), 1);
  }
  if (q.dim() > 0) {
    out_shape.insert(out_shape.begin(), q.numel());
  }

  return out_shape;
}
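// For example (hypothetical shapes): self.sizes() == {3, 4, 5}, dim == 1,
// keepdim == false and q.numel() == 2 (1D q) produce out_shape == {2, 3, 5};
// with a scalar q, the leading quantile dimension is omitted.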

Tensor quantile_compute(
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan,
    int64_t wrapped_dim,
    std::vector<int64_t> out_shape) {
  // Check that all q values are between 0 and 1, inclusive.
  // NOTE: this check is only performed when running on the CPU to avoid
  // synchronizing an accelerator with the CPU.
  if (self.device().is_cpu()) {
    auto all_q_in_range = q.ge(0).logical_and_(q.le(1)).all();
    TORCH_CHECK(at::is_scalar_tensor_true(all_q_in_range),
                "quantile() q values must be in the range [0, 1]");
  }

  // Flatten the input if no dim is provided; otherwise move the dim to reduce
  // to the last position. Sort to efficiently query kth values.
  Tensor sorted;
  if (!original_dim) {
    sorted = std::get<0>(self.flatten().sort());
  } else if (wrapped_dim == self.dim() - 1) {
    sorted = std::get<0>(self.sort());
  } else {
    sorted = std::get<0>(self.unsqueeze(-1).transpose(wrapped_dim, -1).sort());
  }

  // Treat q as a 1D tensor for the following computations
  if (q.dim() == 0) {
    out_shape.insert(out_shape.begin(), q.numel());
  }

  // View input as reduced_size + size of dim to reduce
  std::vector<int64_t> in_shape(out_shape.size());
  std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin());
  in_shape[in_shape.size() - 1] = sorted.size(-1);
  sorted = sorted.view(in_shape);

  // Ensure converting from int64_t to double won't overflow
  TORCH_CHECK(
      sorted.size(-1) <= std::pow(2, 24),
      "quantile() input tensor is too large");

  // Convert q in [0, 1] to ranks in [0, reduction_size)
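  // For example (hypothetical values): with a reduction size of 5 the rank is
  // q * 4, so q == 0.5 maps exactly to rank 2.0, while q == 0.3 maps to rank
  // 1.2, i.e. an interpolation between sorted[1] and sorted[2].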
  Tensor ranks;
  if (ignore_nan) {
    // For nanquantile, compute ranks based on the number of non-nan values.
    // If all values are nan, set the rank to 0 so the computed quantile is nan.
    ranks = q * (sorted.isnan().logical_not_().sum(-1, true) - 1);
    // For Composite Compliance,
    // if `ranks` is a CCT but its tangent is a regular Tensor,
    // then while computing jvp, we end up calling `masked_fill_`
    // on a regular Tensor with CCT args, so we call
    // `masked_fill` instead.
    if (isTensorSubclassLike(ranks) && ranks._fw_grad(/*level=*/0).defined()) {
      ranks = ranks.masked_fill(ranks < 0, 0);
    } else {
      ranks.masked_fill_(ranks < 0, 0);
    }
  } else {
    // For quantile, compute ranks based on the reduction size. If there is a
    // nan, set the rank to the last index so the computed quantile will be nan.
    int64_t last_index = sorted.size(-1) - 1;
    std::vector<Tensor> tl =
        at::broadcast_tensors({q * last_index, sorted.isnan().any(-1, true)});
    ranks = at::masked_fill(tl[0], tl[1], last_index);
  }

  // Adjust ranks based on the interpolation mode
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LOWER) {
    ranks.floor_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::HIGHER) {
    ranks.ceil_();
  } else if (interpolation == QUANTILE_INTERPOLATION_MODE::NEAREST) {
    ranks.round_();
  }
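  // For example (hypothetical values): a rank of 1.2 becomes 1 for "lower",
  // 2 for "higher" and 1 for "nearest"; "linear" keeps 1.2 and interpolates
  // with weight 0.2, while "midpoint" always uses weight 0.5.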

  Tensor ranks_below = ranks.toType(kLong);
  Tensor values_below = sorted.gather(-1, ranks_below);

  // Actual interpolation is only needed for the linear and midpoint modes
  if (interpolation == QUANTILE_INTERPOLATION_MODE::LINEAR ||
      interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT) {
    // Calculate weights for linear and midpoint
    Tensor weights = interpolation == QUANTILE_INTERPOLATION_MODE::MIDPOINT
        ? at::full_like(ranks, 0.5)
        : ranks - ranks_below;

    // Interpolate to compute quantiles and store in values_below
    Tensor ranks_above = ranks.ceil_().toType(kLong);
    Tensor values_above = sorted.gather(-1, ranks_above);
    // For Composite Compliance,
    // if either `values_below`, `values_above` or `weights` is a CCT,
    // or the tangents of `values_above` and `weights` are CCTs,
    // but the tangent of `values_below` is a regular Tensor,
    // then while computing jvp, we would end up copying a CCT
    // into a regular Tensor. So we use the out-of-place variant of `lerp`.
    auto is_primal_cct =
        areAnyTensorSubclassLike({values_below, values_above, weights});
    auto is_tangent_cct = areAnyTensorSubclassLike(
        {values_above._fw_grad(/*level=*/0), weights._fw_grad(/*level=*/0)});
    if ((is_primal_cct || is_tangent_cct) &&
        values_below._fw_grad(/*level=*/0).defined() &&
        !isTensorSubclassLike(values_below._fw_grad(/*level=*/0))) {
      values_below = values_below.lerp(values_above, weights);
    } else {
      values_below.lerp_(values_above, weights);
    }
  }

  if (q.dim() == 0) {
    // If q is scalar, remove last dim to match out shape
    values_below.squeeze_(-1);
  } else {
    // Move quantiles to first dim to match out shape
    values_below.unsqueeze_(0).transpose_(0, -1).squeeze_(-1);
  }

  return values_below;
}

} // namespace

void quantile_out_impl(
    Tensor& out,
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan) {
  quantile_checks(self, q);
  TORCH_CHECK(
      self.scalar_type() == out.scalar_type(),
      "quantile() out tensor must be same dtype as the input tensor");
  TORCH_CHECK(
      self.device() == out.device(),
      "quantile() out tensor must be on the same device as the input tensor");

  int64_t wrapped_dim = at::maybe_wrap_dim(original_dim.value_or(0), self.dim());

  auto out_shape = quantile_output_shape(original_dim, self, q, keepdim, wrapped_dim);
  resize_output(out, out_shape);

  auto quantile = quantile_compute(
      self, q, original_dim, keepdim, interpolation, ignore_nan, wrapped_dim, std::move(out_shape));
  out.copy_(quantile);
}

Tensor quantile_impl(
    const Tensor& self,
    const Tensor& q,
    const std::optional<int64_t> original_dim,
    const bool keepdim,
    const QUANTILE_INTERPOLATION_MODE& interpolation,
    const bool ignore_nan) {
  quantile_checks(self, q);

  int64_t wrapped_dim = at::maybe_wrap_dim(original_dim.value_or(0), self.dim());

  auto out_shape = quantile_output_shape(original_dim, self, q, keepdim, wrapped_dim);

  return quantile_compute(
      self, q, original_dim, keepdim, interpolation, ignore_nan, wrapped_dim, std::move(out_shape));
}

std::tuple<Tensor&, Tensor&> kthvalue_out_impl_cpu(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool keepdim) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  int64_t slicesize = self.dim() == 0 ? 1 : self.size(dim);
  zero_numel_check_dims(self, dim, "kthvalue()");

  TORCH_CHECK(k >= 1 && k <= slicesize,
              "kthvalue(): selected number k out of range for dimension ", dim);

  at::assert_no_overlap(self, values);

  _reduction_with_indices_allocate_or_resize_output(
      values, indices, self, dim_, keepdim);
  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
    return std::forward_as_tuple(values, indices);
  }
  auto tmp_values = self.clone(at::MemoryFormat::Contiguous);
  auto tmp_indices = at::empty(self.sizes(), self.options().dtype(kLong));

  auto tmp_values_stride = tmp_values.strides()[dim];
  auto tmp_indices_stride = tmp_indices.strides()[dim];
  auto sizes = self.sizes();

  TORCH_CHECK(indices.scalar_type() == kLong);

  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(sizes, /*squash_dims=*/dim)
    .add_output(tmp_values)
    .add_output(tmp_indices)
    .add_output(values)
    .add_output(indices)
    .build();

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "kthvalue_cpu", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      for (const auto i : c10::irange(n)) {
        TensorAccessor<scalar_t, 1> tmp_values(
            reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
            &sizes[dim], &tmp_values_stride);
        TensorAccessor<int64_t, 1> tmp_indices(
            reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
            &sizes[dim], &tmp_indices_stride);
        auto mode_value = reinterpret_cast<scalar_t*>(data[2] + i * strides[2]);
        auto mode_index = reinterpret_cast<int64_t*>(data[3] + i * strides[3]);

        for (const auto j : c10::irange(tmp_indices.size(0))) {
          tmp_indices[j] = j;
        }

        // we want NaN to be sorted as top for numpy compatibility
        quick_select_template(
          tmp_values,
          k - 1,
          [](scalar_t x, scalar_t y) -> bool {
            return (
              (_isnan<scalar_t>(x) && !_isnan<scalar_t>(y)) || (x > y));
          },
          [&](int64_t i, int64_t j) {
            std::swap(tmp_values[i], tmp_values[j]);
            std::swap(tmp_indices[i], tmp_indices[j]);
          });
        *mode_value = tmp_values[k - 1];
        *mode_index = tmp_indices[k - 1];
      }
    };

    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });

  if (!keepdim) {
    values.squeeze_(dim);
    indices.squeeze_(dim);
  }
  return std::forward_as_tuple(values, indices);
}

// Computes both the median and its index along dimension dim of the input
std::tuple<Tensor&, Tensor&> median_with_indices_impl(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    bool ignore_nan) {
  dim = at::maybe_wrap_dim(dim, self.dim());

  int64_t size = self.dim() > 0 ? self.size(dim) : 1;
  zero_numel_check_dims(self, dim, "median()");

  checkDeviceType("median", {values, indices}, self.device().type());
  checkScalarType("median", {indices, "indices", 1}, kLong);
  checkSameType("median", {values, "values", 0}, {self, "self", 2});

  std::vector<int64_t> out_shape = self.sizes().vec();
  if (self.dim() > 0) {
    if (keepdim) {
      out_shape[dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + dim);
    }
  }

  resize_output(values, out_shape);
  resize_output(indices, out_shape);

  // Ensure #dim is the same for all tensors required for dim_apply
  Tensor in = self.dim() > 0 ? self : self.unsqueeze(0);
  Tensor vals = keepdim && self.dim() > 0 ? values : values.unsqueeze(dim);
  Tensor inds = keepdim && self.dim() > 0 ? indices : indices.unsqueeze(dim);

  // Make dim to reduce contiguous (stride=1)
  if (in.stride(dim) > 1) {
    in = in.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim).contiguous();
    vals = vals.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim);
    inds = inds.unsqueeze(-1).transpose_(dim, -1).squeeze_(dim);
    dim = in.dim() - 1;
  }

  auto sizes = in.sizes();
  auto iter = TensorIteratorConfig()
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .declare_static_shape(sizes, /*squash_dims=*/dim)
    .add_output(vals)
    .add_output(inds)
    .add_const_input(in)
    .build();

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, in.scalar_type(), "median_out", [&] {
    auto loop = [&](char** data, const int64_t* strides, int64_t n) {
      for (const auto i : c10::irange(n)) {
        auto valp = reinterpret_cast<scalar_t*>(data[0] + i * strides[0]);
        auto indp = reinterpret_cast<int64_t*>(data[1] + i * strides[1]);
        auto ip = reinterpret_cast<const scalar_t*>(data[2] + i * strides[2]);

        // For torch.median, search for NaN and return it if found
        if (!ignore_nan) {
          const scalar_t* nanp = std::find_if(ip, ip + size, _isnan<scalar_t>);
          if (nanp != ip + size) {
            *valp = *nanp;
            *indp = nanp - ip;
            continue;
          }
        }

        // Vector of indices for indirectly partitioning input around median
        std::vector<int64_t> idx(size);
        auto first = idx.begin();
        auto last = idx.end();
        std::iota(first, last, 0);

        // We partition the input around the median indirectly using the indices
        // vector so that nth points to the index of the median in the unmodified
        // input tensor.
        auto nth = first;
        if (!ignore_nan) {
          // If we got here, there are no nan values
          nth += (size - 1) / 2;
          std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) {
            return ip[i] < ip[j] || (ip[i] == ip[j] && i < j);
          });
        } else {
          // For torch.nanmedian, compute median of non-nan values only
          int64_t num_nan = std::count_if(ip, ip + size, _isnan<scalar_t>);
          nth += (size - num_nan - 1) / 2;
          std::nth_element(first, nth, last, [&ip](int64_t i, int64_t j) {
            return ip[i] < ip[j] || (ip[i] == ip[j] && i < j) ||
                (_isnan(ip[j]) && !_isnan(ip[i]));
          });
        }

        *valp = ip[*nth];
        *indp = *nth;
      }
    };
    int64_t grain_size = internal::GRAIN_SIZE / std::max(int64_t{1}, sizes[dim]);
    iter.for_each(loop, /*grain_size=*/grain_size);
  });

  return std::forward_as_tuple(values, indices);
}

// Computes the median of all values in the input
Tensor median_impl(const Tensor& self, bool ignore_nan) {
  NoNamesGuard guard;
  const int64_t size = self.numel();

  // Return nan for empty tensors
  if (size <= 0) {
    return at::full({}, std::numeric_limits<float>::quiet_NaN()).to(self.options());
  }

  // Clone the input tensor so we can partition it around the median value
  Tensor in = self.clone();
  Tensor out = at::empty({}, self.options());

  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, in.scalar_type(), "median_cpu", [&] {
    scalar_t* op = out.data_ptr<scalar_t>();
    scalar_t* first = in.data_ptr<scalar_t>();
    scalar_t* last = first + size;

    // For torch.median, if there are nan values return nan
    if (!ignore_nan && std::any_of(first, last, _isnan<scalar_t>)) {
      *op = std::numeric_limits<scalar_t>::quiet_NaN();
      return;
    }

    scalar_t* median = first;
    if (!ignore_nan) {
      // If we got here, there are no nan values
      median += (size - 1) / 2;
      std::nth_element(first, median, last);
    } else {
      // For torch.nanmedian, compute median of non-nan values only
      int64_t num_nan = std::count_if(first, last, _isnan<scalar_t>);
      median += (size - num_nan - 1) / 2;
      std::nth_element(first, median, last, [](scalar_t a, scalar_t b) {
        return a < b || (_isnan(b) && !_isnan(a));
      });
    }

    *op = *median;
  });

  return out;
}
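// Note that the median index (size - 1) / 2 selects the lower of the two
// middle elements when the count is even, so, for example (hypothetical
// values), the median of {1., 2., 3., 4.} is 2. rather than 2.5.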

} // namespace

Tensor& quantile_out(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  quantile_out_impl(
      out,
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/false);
  return out;
}

Tensor& quantile_out(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::quantile_out(
      self,
      at::scalar_tensor(q, self.options()),
      dim,
      keepdim,
      interpolation,
      out);
}

Tensor quantile(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  return quantile_impl(
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/false);
}

Tensor quantile(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::quantile(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
}
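// A minimal caller-side sketch (hypothetical tensor `t`):
//   Tensor t = at::rand({3, 4});
//   Tensor med = at::quantile(t, 0.5, /*dim=*/1, /*keepdim=*/false, "linear");
// which reduces dim 1 and returns a tensor of shape {3}.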

Tensor& nanquantile_out(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  quantile_out_impl(
      out,
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/true);
  return out;
}

Tensor& nanquantile_out(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation,
    Tensor& out) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::nanquantile_out(
      self,
      at::scalar_tensor(q, self.options()),
      dim,
      keepdim,
      interpolation,
      out);
}

Tensor nanquantile(
    const Tensor& self,
    const Tensor& q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  return quantile_impl(
      self,
      q,
      dim,
      keepdim,
      get_quantile_interpolation_mode(interpolation),
      /*ignore_nan=*/true);
}

Tensor nanquantile(
    const Tensor& self,
    double q,
    std::optional<int64_t> dim,
    bool keepdim,
    const c10::string_view interpolation) {
  TORCH_CHECK(
      q >= 0 && q <= 1, "quantile() q must be in the range [0, 1] but got ", q);
  return at::native::nanquantile(
      self, at::scalar_tensor(q, self.options()), dim, keepdim, interpolation);
}

std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return kthvalue_out_impl_cpu(values, indices, self, k, dim, keepdim);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> kthvalue_out(
    const Tensor& self,
    int64_t k,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::kthvalue_out(
      values, indices, self, k, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> kthvalue(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::kthvalue_out(values, indices, self, k, dim, keepdim);
  return std::make_tuple(values, indices);
}
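// For example (hypothetical values): for a 1-D tensor {3., 1., 2.},
// at::kthvalue(t, 2) returns the 2nd smallest value, 2., together with its
// index in the original tensor, 2 (k is 1-based).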

std::tuple<Tensor, Tensor> kthvalue(
    const Tensor& self,
    int64_t k,
    Dimname dim,
    bool keepdim) {
  return at::kthvalue(self, k, dimname_to_position(self, dim), keepdim);
}

TORCH_IMPL_FUNC(topk_out_cpu)
   (const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool largest,
    bool sorted,
    const Tensor& values,
    const Tensor& indices) {
  int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  TORCH_CHECK(
      k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
      "selected index k out of range");

  if (self.dim() == 0 && self.numel() == 1) {
    values.copy_(self);
    indices.zero_();
  } else {
    topk_stub(kCPU, values, indices, self, k, dim, largest, sorted);
  }
}

std::tuple<Tensor&, Tensor&> median_out_cpu(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return median_with_indices_impl(
        values, indices, self, dim, keepdim, /*ignore_nan=*/false);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> median_out(
    const Tensor& self,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::median_out(
      values, indices, self, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> median(
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::median_out(values, indices, self, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> median(
    const Tensor& self,
    Dimname dim,
    bool keepdim) {
  return at::median(self, dimname_to_position(self, dim), keepdim);
}

Tensor median_cpu(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/false);
}

std::tuple<Tensor&, Tensor&> nanmedian_out_cpu(
    const Tensor& self,
    int64_t dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  auto result = [&]() {
    NoNamesGuard guard;
    return median_with_indices_impl(
        values, indices, self, dim, keepdim, /*ignore_nan=*/true);
  }();
  namedinference::propagate_names_for_reduction(values, self, dim, keepdim);
  namedinference::propagate_names_for_reduction(indices, self, dim, keepdim);
  return result;
}

std::tuple<Tensor&, Tensor&> nanmedian_out(
    const Tensor& self,
    Dimname dim,
    bool keepdim,
    Tensor& values,
    Tensor& indices) {
  return at::nanmedian_out(
      values, indices, self, dimname_to_position(self, dim), keepdim);
}

std::tuple<Tensor, Tensor> nanmedian(
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  Tensor values = at::empty({0}, self.options());
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::nanmedian_out(values, indices, self, dim, keepdim);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor, Tensor> nanmedian(
    const Tensor& self,
    Dimname dim,
    bool keepdim) {
  return at::nanmedian(self, dimname_to_position(self, dim), keepdim);
}

Tensor nanmedian_cpu(const Tensor& self) {
  return median_impl(self, /*ignore_nan=*/true);
}

TORCH_IMPL_FUNC(sort_stable_out)
(const Tensor& self,
 std::optional<bool> stable,
 int64_t dim,
 bool descending,
 const Tensor& values,
 const Tensor& indices) {
  values.copy_(self);
  // check if self is scalar
  if (self.dim() == 0 && self.numel() == 1) {
    indices.zero_();
  } else {
    dim = maybe_wrap_dim(dim, self.dim());
    sort_stub(self.device().type(), self, values, indices, dim, descending, stable.value_or(false));
  }
}

std::tuple<Tensor&, Tensor&> sort_out(
    const Tensor& self,
    int64_t dim,
    bool descending,
    Tensor& values,
    Tensor& indices) {
  return at::sort_out(values, indices, self, false, dim, descending);
}

std::tuple<Tensor, Tensor> sort(
    const Tensor& self,
    int64_t dim,
    bool descending) {
  return at::sort(self, false, dim, descending);
}
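// A minimal caller-side sketch (hypothetical tensor `t`):
//   auto [vals, idxs] = at::sort(t, /*stable=*/true, /*dim=*/0, /*descending=*/false);
// The non-stable overloads above simply forward with stable == false.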

Tensor& msort_out(const Tensor& self, Tensor& values) {
  Tensor indices = at::empty({0}, self.options().dtype(kLong));
  at::sort_out(values, indices, self, 0, false);
  return values;
}

Tensor msort(const Tensor& self) {
  return std::get<0>(at::sort(self, 0, false));
}

Tensor argsort(const Tensor & self, int64_t dim, bool descending) {
  return std::get<1>(at::sort(self, dim, descending));
}
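// For example (hypothetical values): for a 1-D tensor {3., 1., 2.},
// at::argsort(t) returns {1, 2, 0}, the indices that would sort t ascending.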

Tensor argsort(const Tensor & self, bool stable, int64_t dim, bool descending) {
  return std::get<1>(at::sort(self, stable, dim, descending));
}

Tensor& argsort_out(const Tensor & self, bool stable, int64_t dim, bool descending, Tensor& out) {
  auto values = at::empty({0}, self.options());
  at::sort_outf(self, stable, dim, descending, values, out);
  return out;
}


} // namespace at::native