#pragma once
#include <ATen/core/TensorAccessor.h>
#include <ATen/NumericUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
#include <vector>

namespace at::native {

#ifdef CPU_CAPABILITY
inline namespace CPU_CAPABILITY {
#else
inline namespace DEFAULT {
#endif

// Core topk loop, shared between CPU and QuantizedCPU
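// The (data, strides, n) triple follows the TensorIterator 1-D loop convention:
// data[0]/strides[0] point at the top-k values output, data[1]/strides[1] at the
// top-k indices output, and data[2]/strides[2] at the input; each pointer is
// advanced by its byte stride once per outer iteration (n iterations in total).
// The *_stride arguments are element strides along the reduced (top-k) dim.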
template <typename scalar_t, typename accscalar_t>
void topk_impl_loop(
    const int64_t mode_values_stride,
    const int64_t mode_indices_stride,
    const int64_t tmp_values_stride,
    const int64_t k,
    const int64_t dim_size,
    const bool largest,
    const bool sorted,
    char** data, const int64_t* strides, const int64_t n) {

  // If k is zero, the output values and indices are empty tensors,
  // so iterating over the other dims is pointless.
  if (k == 0) {
    return;
  }
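  // Scratch buffer of (value, original index) pairs covering one full slice
  // along the reduced dimension; reused across all n outer iterations.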
  using elem_t = std::pair<accscalar_t, int64_t>;
  std::vector<elem_t> queue(dim_size);
  for (const auto i : c10::irange(n)) {
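    // Wrap the raw pointers for this outer iteration in 1-D accessors: the
    // outputs (values, indices) have length k, the input slice has length dim_size.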
    TensorAccessor<scalar_t, 1> mode_values(
        reinterpret_cast<scalar_t*>(data[0] + i * strides[0]),
        &k, &mode_values_stride);
    TensorAccessor<int64_t, 1> mode_indices(
        reinterpret_cast<int64_t*>(data[1] + i * strides[1]),
        &k, &mode_indices_stride);
    TensorAccessor<const scalar_t, 1> tmp_values(
        reinterpret_cast<const scalar_t*>(data[2] + i * strides[2]),
        &dim_size, &tmp_values_stride);

    auto n_2 = dim_size;
    auto use_partial_sort = k * 64 <= n_2;

    for (const auto j : c10::irange(n_2)) {
      queue[j].first = tmp_values[j];
      queue[j].second = j;
    }

    // NaN compares as the largest value for NumPy compatibility: it shows up
    // in the top-k when largest is true and sinks to the back otherwise.
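    // Heuristic: when k is much smaller than the slice (k * 64 <= dim_size),
    // std::partial_sort is used; otherwise std::nth_element selects the top k
    // and the prefix is fully sorted only when sorted output is requested.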
    if (use_partial_sort) {
      if (largest) {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
      } else {
        std::partial_sort(queue.begin(), queue.begin() + k, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
      }
    } else {
      if (largest) {
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
          });
        if (sorted) {
          // After nth_element, queue[k - 1] is already in its final position,
          // so only the first k - 1 elements still need sorting.
          std::sort(queue.begin(), queue.begin() + k - 1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((_isnan<accscalar_t>(x.first) && !_isnan<accscalar_t>(y.first)) || (x.first > y.first));
            });
        }
      } else {
        std::nth_element(queue.begin(), queue.begin() + k - 1, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool {
            return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
          });
        if (sorted) {
          std::sort(queue.begin(), queue.begin() + k - 1,
            [](const elem_t& x, const elem_t& y) -> bool {
              return ((!_isnan<accscalar_t>(x.first) && _isnan<accscalar_t>(y.first)) || (x.first < y.first));
            });
        }
      }
    }

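    // Write the selected top-k (value, original index) pairs to the outputs.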
    for (const auto j : c10::irange(k)) {
      mode_values[j] = queue[j].first;
      mode_indices[j] = queue[j].second;
    }
  }
}
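
// Illustrative call site (a sketch, not the exact kernel code): a caller is
// expected to build a TensorIterator over the two outputs and the input with
// the reduced dim squashed, then drive this loop along the lines of
//
//   auto loop = [&](char** data, const int64_t* strides, int64_t n) {
//     topk_impl_loop<scalar_t, scalar_t>(
//         values.strides()[dim], indices.strides()[dim], self.strides()[dim],
//         k, self.sizes()[dim], largest, sorted, data, strides, n);
//   };
//   iter.for_each(loop, grain_size);
//
// where iter, values, indices, self, dim, and grain_size are placeholders for
// whatever the caller has in scope.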

} // namespace CPU_CAPABILITY
} // namespace at::native