xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_argmin.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <cmath>
10 #include <tuple>
11 
12 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 #include <executorch/runtime/platform/assert.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
20 using exec_aten::optional;
21 using exec_aten::Tensor;
22 
argmin_out(KernelRuntimeContext & ctx,const Tensor & in,optional<int64_t> dim,bool keepdim,Tensor & out)23 Tensor& argmin_out(
24     KernelRuntimeContext& ctx,
25     const Tensor& in,
26     optional<int64_t> dim,
27     bool keepdim,
28     Tensor& out) {
29   (void)ctx;
30 
31   ET_KERNEL_CHECK(
32       ctx,
33       check_argmin_argmax_args(in, dim, keepdim, out),
34       InvalidArgument,
35       out);
36 
37   ET_KERNEL_CHECK(
38       ctx,
39       resize_reduction_out(in, dim, keepdim, out) == Error::Ok,
40       InvalidArgument,
41       out);
42 
43   ET_KERNEL_CHECK(
44       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
45 
46   ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
47     long* out_data = out.mutable_data_ptr<long>();
48 
49     for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
50       std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
51           [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
52             if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
53               acc_val = v;
54               acc_ix = ix;
55             }
56             return std::tuple<CTYPE, long>{acc_val, acc_ix};
57           },
58           in,
59           dim,
60           out_ix);
61       out_data[out_ix] = std::get<1>(acc);
62     }
63   });
64 
65   return out;
66 }
67 
68 } // namespace native
69 } // namespace executor
70 } // namespace torch
71