xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_topk.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <algorithm>
#include <cmath>
#include <tuple>
#include <utility>

#include <executorch/runtime/kernel/kernel_includes.h>
13 
14 namespace torch {
15 namespace executor {
16 namespace native {
17 namespace {
18 
// Validates the arguments to topk: output dtypes, the dim bound, and the
// range of k. Returns true when all checks pass; each ET_LOG_*_RETURN_IF_FALSE
// macro logs and returns false on the first failing check.
bool check_topk_args(
    const Tensor& in,
    int64_t k,
    int64_t dim,
    Tensor& values,
    Tensor& indices) {
  // `values` must mirror the input dtype; `indices` must always be int64.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, values));
  ET_LOG_AND_RETURN_IF_FALSE(indices.scalar_type() == ScalarType::Long);
  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
  // Normalize Python-style negative dim before bounds-checking k. `dim` is
  // passed by value, so the caller's copy is unaffected.
  if (dim < 0) {
    dim += nonzero_dim(in);
  }
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      k >= 0 && k <= nonempty_size(in, dim), "selected index k out of range");
  return true;
}
35 
get_topk_target_size(const Tensor & in,int64_t k,int64_t dim,Tensor::SizesType * target_size,size_t * target_dim)36 bool get_topk_target_size(
37     const Tensor& in,
38     int64_t k,
39     int64_t dim,
40     Tensor::SizesType* target_size,
41     size_t* target_dim) {
42   *target_dim = in.dim();
43   for (size_t i = 0; i < *target_dim; ++i) {
44     if (i == dim) {
45       target_size[i] = k;
46     } else {
47       target_size[i] = in.size(i);
48     }
49   }
50   return true;
51 }
52 
53 template <typename CTYPE, typename elem_t = std::pair<CTYPE, int64_t>>
perform_topk(const Tensor & in,int64_t k,int64_t dim,bool largest,bool sorted,Tensor & values,Tensor & indices,elem_t * queue)54 void perform_topk(
55     const Tensor& in,
56     int64_t k,
57     int64_t dim,
58     bool largest,
59     bool sorted,
60     Tensor& values,
61     Tensor& indices,
62     elem_t* queue) {
63   const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
64   CTYPE* values_data = values.mutable_data_ptr<CTYPE>();
65   long* indices_data = indices.mutable_data_ptr<long>();
66 
67   if (in.dim() == 0) {
68     values_data[0] = in_data[0];
69     indices_data[0] = 0;
70     return;
71   }
72 
73   if (k == 0) {
74     return;
75   }
76 
77   const size_t outer_size = getLeadingDims(in, dim);
78 
79   const size_t dim_size = in.size(dim);
80   const size_t dim_stride = in.strides()[dim];
81 
82   const size_t outer_stride_in = dim_size * dim_stride;
83   const size_t outer_stride_out = k * dim_stride;
84 
85   bool use_partial_sort = k * 64 <= dim_size;
86 
87   // Loop through all outer dimensions
88   for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
89     size_t outer_in = outer_idx * outer_stride_in;
90     size_t outer_out = outer_idx * outer_stride_out;
91     // Loop through all inner dimensions
92     for (size_t inner_idx = 0; inner_idx < dim_stride; ++inner_idx) {
93       size_t base_in = outer_in + inner_idx;
94       size_t base_out = outer_out + inner_idx;
95 
96       // Populate the queue with the values from the input tensor
97       for (size_t i = 0; i < dim_size; ++i) {
98         size_t in_ix = base_in + i * dim_stride;
99         queue[i].first = in_data[in_ix];
100         queue[i].second = i;
101       }
102 
103       // Perform topk on the queue
104       if (use_partial_sort) {
105         if (largest) {
106           std::partial_sort(
107               queue,
108               queue + k,
109               queue + dim_size,
110               [](const elem_t& x, const elem_t& y) -> bool {
111                 return (
112                     (std::isnan(x.first) && !std::isnan(y.first)) ||
113                     (x.first > y.first));
114               });
115         } else {
116           std::partial_sort(
117               queue,
118               queue + k,
119               queue + dim_size,
120               [](const elem_t& x, const elem_t& y) -> bool {
121                 return (
122                     (!std::isnan(x.first) && std::isnan(y.first)) ||
123                     (x.first < y.first));
124               });
125         }
126       } else {
127         if (largest) {
128           std::nth_element(
129               queue,
130               queue + k - 1,
131               queue + dim_size,
132               [](const elem_t& x, const elem_t& y) -> bool {
133                 return (
134                     (std::isnan(x.first) && !std::isnan(y.first)) ||
135                     (x.first > y.first));
136               });
137           if (sorted) {
138             std::sort(
139                 queue,
140                 queue + k - 1,
141                 [](const elem_t& x, const elem_t& y) -> bool {
142                   return (
143                       (std::isnan(x.first) && !std::isnan(y.first)) ||
144                       (x.first > y.first));
145                 });
146           }
147         } else {
148           std::nth_element(
149               queue,
150               queue + k - 1,
151               queue + dim_size,
152               [](const elem_t& x, const elem_t& y) -> bool {
153                 return (
154                     (!std::isnan(x.first) && std::isnan(y.first)) ||
155                     (x.first < y.first));
156               });
157           if (sorted) {
158             std::sort(
159                 queue,
160                 queue + k - 1,
161                 [](const elem_t& x, const elem_t& y) -> bool {
162                   return (
163                       (!std::isnan(x.first) && std::isnan(y.first)) ||
164                       (x.first < y.first));
165                 });
166           }
167         }
168       }
169 
170       // Write the topk values and indices to the output tensors
171       for (size_t i = 0; i < k; ++i) {
172         size_t out_ix = base_out + i * dim_stride;
173 
174         values_data[out_ix] = queue[i].first;
175         indices_data[out_ix] = queue[i].second;
176       }
177     }
178   }
179 }
180 
allocate_temp_memory(KernelRuntimeContext & ctx,size_t size)181 void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) {
182   Result<void*> temp_mem_res = ctx.allocate_temp(size);
183   return temp_mem_res.ok() ? temp_mem_res.get() : nullptr;
184 }
185 
186 } // namespace
187 
// topk.values out-variant: writes the k largest (smallest when !largest)
// elements of `in` along `dim` into `values`, and their source indices into
// `indices`. Returns (values, indices); on a failed check the outputs are
// returned unmodified and ET_KERNEL_CHECK flags the error on ctx.
std::tuple<Tensor&, Tensor&> topk_values(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted,
    Tensor& values,
    Tensor& indices) {
  auto out = std::tuple<Tensor&, Tensor&>({values, indices});

  ET_KERNEL_CHECK(
      ctx, check_topk_args(in, k, dim, values, indices), InvalidArgument, out);

  // Normalize Python-style negative dim to [0, in.dim()).
  if (dim < 0) {
    dim += nonzero_dim(in);
  }

  // Output shape is the input shape with extent k along `dim`; resize both
  // outputs to it before computing.
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType target_size[kTensorDimensionLimit];
  size_t target_dim = 0;
  get_topk_target_size(in, k, dim, target_size, &target_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(values, {target_size, target_dim}) == Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(indices, {target_size, target_dim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "topk.values";

  // Empty input, or k == 0 on a non-0-d tensor: the outputs are already the
  // right (empty) size, nothing to compute. A 0-d input deliberately falls
  // through so perform_topk can copy its single element.
  if (in.numel() == 0 || (k == 0 && in.dim() > 0)) {
    return out;
  }

  bool temp_mem_allocated = false;

  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
    using elem_t = std::pair<CTYPE, int64_t>;
    // Scratch buffer: one (value, index) pair per element of the reduced
    // dimension, reused across every slice by perform_topk.
    size_t temp_mem_size = nonempty_size(in, dim) * sizeof(elem_t);

    elem_t* queue = (elem_t*)allocate_temp_memory(ctx, temp_mem_size);
    if (queue == nullptr) {
      return;
    }
    temp_mem_allocated = true;

    perform_topk<CTYPE>(in, k, dim, largest, sorted, values, indices, queue);
  });

  // Flags a temp-allocation failure (this also trips when the dtype switch
  // never ran the lambda above).
  ET_KERNEL_CHECK(ctx, temp_mem_allocated, MemoryAllocationFailed, out);

  return out;
}
248 
249 } // namespace native
250 } // namespace executor
251 } // namespace torch
252