xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TopKImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/TensorAccessor.h>
3 #include <ATen/NumericUtils.h>
4 
5 namespace at::native {
6 
7 #ifdef CPU_CAPABILITY
8 inline namespace CPU_CAPABILITY {
9 #else
10 inline namespace DEFAULT {
11 #endif
12 
13 // Core topk loop, shared between CPU and QuantizedCPU
14 template <typename scalar_t, typename accscalar_t>
topk_impl_loop(const int64_t mode_values_stride,const int64_t mode_indices_stride,const int64_t tmp_values_stride,const int64_t k,const int64_t dim_size,const bool largest,const bool sorted,char ** data,const int64_t * strides,const int64_t n)15 void topk_impl_loop(
16     const int64_t mode_values_stride,
17     const int64_t mode_indices_stride,
18     const int64_t tmp_values_stride,
19     const int64_t k,
20     const int64_t dim_size,
21     const bool largest,
22     const bool sorted,
23     char** data, const int64_t* strides, const int64_t n) {
24 
25   // If k is zero, then output values and indices are empty tensors
26   // So iterating over other dims is pointless
27   if (k == 0) {
28     return;
29   }
30   using elem_t = std::pair<accscalar_t, int64_t>;
31   std::vector<elem_t> queue(dim_size);
32   for (const auto i : c10::irange(n)) {
33     TensorAccessor<scalar_t, 1> mode_values(
34         reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
35         &k, &mode_values_stride);
36     TensorAccessor<int64_t, 1> mode_indices(
37         reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
38         &k, &mode_indices_stride);
39     TensorAccessor<const scalar_t, 1> tmp_values(
40         reinterpret_cast<scalar_t*>(data[2] + i * strides[2]),
41         &dim_size, &tmp_values_stride);
42 
43     auto n_2 = dim_size;
44     auto use_partial_sort = k * 64 <= n_2;
45 
46     for (const auto j : c10::irange(n_2)) {
47       queue[j].first = tmp_values[j];
48       queue[j].second = j;
49     }
50 
51     // we want nan to be sorted as top for numpy compatibility
52     if (use_partial_sort) {
53       if (largest) {
54         std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
55           [](const elem_t& x, const elem_t& y) -> bool {
56             return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
57           });
58       } else {
59         std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
60           [](const elem_t& x, const elem_t& y) -> bool {
61             return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
62           });
63       }
64     } else {
65       if (largest) {
66         std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
67           [](const elem_t& x, const elem_t& y) -> bool {
68             return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
69           });
70         if (sorted) {
71           std::sort(queue.begin(), queue.begin() + k - 1,
72             [](const elem_t& x, const elem_t& y) -> bool {
73               return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
74             });
75         }
76       } else {
77         std::nth_element(queue.begin(), queue.begin() + k -1, queue.end(),
78           [](const elem_t& x, const elem_t& y) -> bool {
79             return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
80           });
81         if (sorted) {
82           std::sort(queue.begin(), queue.begin() + k -1,
83             [](const elem_t& x, const elem_t& y) -> bool {
84               return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
85             });
86         }
87       }
88     }
89 
90     for (const auto j : c10::irange(k)) {
91       mode_values[j] = queue[j].first;
92       mode_indices[j] = queue[j].second;
93     }
94   }
95 }
96 
97 } // namespace CPU_CAPABILITY
98 } // namespace at::native
99