/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

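// Portable CPU kernels for the aten::any operator family:
//   - any.all_out:  reduce the whole tensor to a single boolean.
//   - any.dims_out: reduce over an optional list of dimensions.
//   - any.out:      reduce over a single dimension.
// Each kernel writes true wherever at least one reduced element is truthy.
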
namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;

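// any.all_out: resizes `out` to a zero-dimensional tensor and stores true
// iff any element of `in` is truthy. For example (a sketch; the exact call
// site depends on the runtime), with in = [0, 0, 3] the kernel stores true
// and the scan stops at the first truthy element it finds.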
Tensor& any_all_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  ScalarType out_type = out.scalar_type();
  constexpr auto name = "any.all_out";

  ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
      const auto data_in = in.const_data_ptr<CTYPE_IN>();
      auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
      // Scan the flattened input and stop at the first truthy element.
      data_out[0] = static_cast<CTYPE_OUT>(false);
      for (auto i = 0; i < in.numel(); ++i) {
        if (static_cast<bool>(data_in[i])) {
          data_out[0] = static_cast<CTYPE_OUT>(true);
          break;
        }
      }
    });
  });

  return out;
}

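// any.dims_out: OR-reduction over the dimensions in `dim_list`. A missing
// dim_list reduces over all dimensions; an empty dim_list reduces over none,
// so the output keeps the input's shape. For example (a sketch), with a 2x3
// input and dim_list = {1}, out has shape [2] ([2, 1] when keepdim is true)
// and out[i] is true iff row i contains a truthy element.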
Tensor& any_dims_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    bool keepdim,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_reduction_args(in, dim_list, keepdim, {}, out),
      InvalidArgument,
      out);

  // An empty (but present) dim_list reduces over no dimensions, so the
  // output has the same shape as the input.
  if (dim_list.has_value() && dim_list.value().empty()) {
    ET_KERNEL_CHECK(
        ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
  } else {
    ET_KERNEL_CHECK(
        ctx,
        resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok,
        InvalidArgument,
        out);
  }

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  ScalarType out_type = out.scalar_type();
  constexpr auto name = "any.dims_out";

  ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
      if (dim_list.has_value() && dim_list.value().empty()) {
        // No reduction: copy each element's truthiness through.
        const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
        for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
          out_data[out_ix] =
              static_cast<CTYPE_OUT>(static_cast<bool>(in_data[out_ix]));
        }
      } else {
        // OR-reduce, for each output index, the input slice addressed by
        // out_ix over the listed dimensions.
        for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
          bool any = false;
          if (in.numel() > 0) {
            any = map_reduce_over_dim_list<CTYPE_IN, bool>(
                [](CTYPE_IN v) { return static_cast<bool>(v); },
                [](bool outv, bool acc) { return acc || outv; },
                in,
                dim_list,
                out_ix);
          }
          out_data[out_ix] = static_cast<CTYPE_OUT>(any);
        }
      }
    });
  });

  return out;
}

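// any.out: single-dimension variant. out[out_ix] is true iff the slice of
// `in` along `dim` addressed by out_ix contains a truthy element. For
// example (a sketch), with a 2x3 input and dim = 0, out has shape [3] and
// out[j] is true iff column j contains a truthy element.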
Tensor& any_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t dim,
    bool keepdim,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx,
      check_reduction_args_single_dim(
          in, dim, keepdim, {}, out, /*allow_empty_dim*/ true),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_reduction_out(in, dim, keepdim, out) == Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  ScalarType out_type = out.scalar_type();
  constexpr auto name = "any.out";

  ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
    ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
      for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
        CTYPE_OUT any = false;
        if (in.numel() > 0) {
          // map_reduce_over_dim also threads an index through the reducer;
          // it is unused here, so the reducer always returns 0 for it and
          // only the reduced value is read out of the tuple.
          std::tuple<CTYPE_OUT, long> acc =
              map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
                  [](CTYPE_IN v) { return static_cast<bool>(v); },
                  [](bool outv, long, bool acc, long) {
                    return std::tuple<bool, long>{acc || outv, 0};
                  },
                  in,
                  dim,
                  out_ix);
          any = std::get<0>(acc);
        }
        out_data[out_ix] = any;
      }
    });
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch