1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <algorithm>
#include <cmath>
#include <tuple>
#include <utility>

#include <executorch/runtime/kernel/kernel_includes.h>
13
14 namespace torch {
15 namespace executor {
16 namespace native {
17 namespace {
18
check_topk_args(const Tensor & in,int64_t k,int64_t dim,Tensor & values,Tensor & indices)19 bool check_topk_args(
20 const Tensor& in,
21 int64_t k,
22 int64_t dim,
23 Tensor& values,
24 Tensor& indices) {
25 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, values));
26 ET_LOG_AND_RETURN_IF_FALSE(indices.scalar_type() == ScalarType::Long);
27 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
28 if (dim < 0) {
29 dim += nonzero_dim(in);
30 }
31 ET_LOG_MSG_AND_RETURN_IF_FALSE(
32 k >= 0 && k <= nonempty_size(in, dim), "selected index k out of range");
33 return true;
34 }
35
get_topk_target_size(const Tensor & in,int64_t k,int64_t dim,Tensor::SizesType * target_size,size_t * target_dim)36 bool get_topk_target_size(
37 const Tensor& in,
38 int64_t k,
39 int64_t dim,
40 Tensor::SizesType* target_size,
41 size_t* target_dim) {
42 *target_dim = in.dim();
43 for (size_t i = 0; i < *target_dim; ++i) {
44 if (i == dim) {
45 target_size[i] = k;
46 } else {
47 target_size[i] = in.size(i);
48 }
49 }
50 return true;
51 }
52
53 template <typename CTYPE, typename elem_t = std::pair<CTYPE, int64_t>>
perform_topk(const Tensor & in,int64_t k,int64_t dim,bool largest,bool sorted,Tensor & values,Tensor & indices,elem_t * queue)54 void perform_topk(
55 const Tensor& in,
56 int64_t k,
57 int64_t dim,
58 bool largest,
59 bool sorted,
60 Tensor& values,
61 Tensor& indices,
62 elem_t* queue) {
63 const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
64 CTYPE* values_data = values.mutable_data_ptr<CTYPE>();
65 long* indices_data = indices.mutable_data_ptr<long>();
66
67 if (in.dim() == 0) {
68 values_data[0] = in_data[0];
69 indices_data[0] = 0;
70 return;
71 }
72
73 if (k == 0) {
74 return;
75 }
76
77 const size_t outer_size = getLeadingDims(in, dim);
78
79 const size_t dim_size = in.size(dim);
80 const size_t dim_stride = in.strides()[dim];
81
82 const size_t outer_stride_in = dim_size * dim_stride;
83 const size_t outer_stride_out = k * dim_stride;
84
85 bool use_partial_sort = k * 64 <= dim_size;
86
87 // Loop through all outer dimensions
88 for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
89 size_t outer_in = outer_idx * outer_stride_in;
90 size_t outer_out = outer_idx * outer_stride_out;
91 // Loop through all inner dimensions
92 for (size_t inner_idx = 0; inner_idx < dim_stride; ++inner_idx) {
93 size_t base_in = outer_in + inner_idx;
94 size_t base_out = outer_out + inner_idx;
95
96 // Populate the queue with the values from the input tensor
97 for (size_t i = 0; i < dim_size; ++i) {
98 size_t in_ix = base_in + i * dim_stride;
99 queue[i].first = in_data[in_ix];
100 queue[i].second = i;
101 }
102
103 // Perform topk on the queue
104 if (use_partial_sort) {
105 if (largest) {
106 std::partial_sort(
107 queue,
108 queue + k,
109 queue + dim_size,
110 [](const elem_t& x, const elem_t& y) -> bool {
111 return (
112 (std::isnan(x.first) && !std::isnan(y.first)) ||
113 (x.first > y.first));
114 });
115 } else {
116 std::partial_sort(
117 queue,
118 queue + k,
119 queue + dim_size,
120 [](const elem_t& x, const elem_t& y) -> bool {
121 return (
122 (!std::isnan(x.first) && std::isnan(y.first)) ||
123 (x.first < y.first));
124 });
125 }
126 } else {
127 if (largest) {
128 std::nth_element(
129 queue,
130 queue + k - 1,
131 queue + dim_size,
132 [](const elem_t& x, const elem_t& y) -> bool {
133 return (
134 (std::isnan(x.first) && !std::isnan(y.first)) ||
135 (x.first > y.first));
136 });
137 if (sorted) {
138 std::sort(
139 queue,
140 queue + k - 1,
141 [](const elem_t& x, const elem_t& y) -> bool {
142 return (
143 (std::isnan(x.first) && !std::isnan(y.first)) ||
144 (x.first > y.first));
145 });
146 }
147 } else {
148 std::nth_element(
149 queue,
150 queue + k - 1,
151 queue + dim_size,
152 [](const elem_t& x, const elem_t& y) -> bool {
153 return (
154 (!std::isnan(x.first) && std::isnan(y.first)) ||
155 (x.first < y.first));
156 });
157 if (sorted) {
158 std::sort(
159 queue,
160 queue + k - 1,
161 [](const elem_t& x, const elem_t& y) -> bool {
162 return (
163 (!std::isnan(x.first) && std::isnan(y.first)) ||
164 (x.first < y.first));
165 });
166 }
167 }
168 }
169
170 // Write the topk values and indices to the output tensors
171 for (size_t i = 0; i < k; ++i) {
172 size_t out_ix = base_out + i * dim_stride;
173
174 values_data[out_ix] = queue[i].first;
175 indices_data[out_ix] = queue[i].second;
176 }
177 }
178 }
179 }
180
allocate_temp_memory(KernelRuntimeContext & ctx,size_t size)181 void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) {
182 Result<void*> temp_mem_res = ctx.allocate_temp(size);
183 return temp_mem_res.ok() ? temp_mem_res.get() : nullptr;
184 }
185
186 } // namespace
187
// Out-variant entry point for topk.values: writes the `k` largest (or
// smallest) elements of `in` along `dim`, plus their indices, into the
// pre-declared output tensors `values` and `indices`, and returns both as a
// tuple of references.
std::tuple<Tensor&, Tensor&> topk_values(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted,
    Tensor& values,
    Tensor& indices) {
  auto out = std::tuple<Tensor&, Tensor&>({values, indices});

  ET_KERNEL_CHECK(
      ctx, check_topk_args(in, k, dim, values, indices), InvalidArgument, out);

  // Normalize a negative dim to its non-negative equivalent.
  if (dim < 0) {
    dim += nonzero_dim(in);
  }

  // Resize both outputs to the input shape with `dim` clamped to k.
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType target_size[kTensorDimensionLimit];
  size_t target_dim = 0;
  get_topk_target_size(in, k, dim, target_size, &target_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(values, {target_size, target_dim}) == Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(indices, {target_size, target_dim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "topk.values";

  // Nothing to compute for an empty input, or for k == 0 on a non-0-d
  // tensor (a 0-d tensor still yields its single element in perform_topk).
  if (in.numel() == 0 || (k == 0 && in.dim() > 0)) {
    return out;
  }

  bool temp_mem_allocated = false;

  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
    using elem_t = std::pair<CTYPE, int64_t>;
    // Scratch buffer: one (value, index) pair per element along `dim`.
    size_t temp_mem_size = nonempty_size(in, dim) * sizeof(elem_t);

    elem_t* queue = (elem_t*)allocate_temp_memory(ctx, temp_mem_size);
    if (queue == nullptr) {
      return;
    }
    temp_mem_allocated = true;

    perform_topk<CTYPE>(in, k, dim, largest, sorted, values, indices, queue);
  });

  // The switch lambda can only signal allocation failure by leaving this
  // flag unset, so the check must happen out here.
  ET_KERNEL_CHECK(ctx, temp_mem_allocated, MemoryAllocationFailed, out);

  return out;
}
248
249 } // namespace native
250 } // namespace executor
251 } // namespace torch
252