1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11
12 namespace torch {
13 namespace executor {
14 namespace native {
15
16 using Tensor = exec_aten::Tensor;
17 using ScalarType = exec_aten::ScalarType;
18
any_all_out(KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)19 Tensor& any_all_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
20 (void)ctx;
21
22 ET_KERNEL_CHECK(
23 ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
24
25 ET_KERNEL_CHECK(
26 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
27
28 ScalarType in_type = in.scalar_type();
29 ScalarType out_type = out.scalar_type();
30 constexpr auto name = "any.all_out";
31
32 ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
33 ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
34 const auto data_in = in.const_data_ptr<CTYPE_IN>();
35 auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
36 data_out[0] = static_cast<CTYPE_OUT>(false);
37 for (auto i = 0; i < in.numel(); ++i) {
38 if (static_cast<bool>(data_in[i])) {
39 data_out[0] = static_cast<CTYPE_OUT>(true);
40 break;
41 }
42 }
43 });
44 });
45
46 return out;
47 }
48
any_dims_out(KernelRuntimeContext & ctx,const Tensor & in,optional<ArrayRef<int64_t>> dim_list,bool keepdim,Tensor & out)49 Tensor& any_dims_out(
50 KernelRuntimeContext& ctx,
51 const Tensor& in,
52 optional<ArrayRef<int64_t>> dim_list,
53 bool keepdim,
54 Tensor& out) {
55 (void)ctx;
56
57 ET_KERNEL_CHECK(
58 ctx,
59 check_reduction_args(in, dim_list, keepdim, {}, out),
60 InvalidArgument,
61 out);
62
63 if (dim_list.has_value() && dim_list.value().empty()) {
64 ET_KERNEL_CHECK(
65 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
66 } else {
67 ET_KERNEL_CHECK(
68 ctx,
69 resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok,
70 InvalidArgument,
71 out);
72 }
73
74 ET_KERNEL_CHECK(
75 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
76
77 ScalarType in_type = in.scalar_type();
78 ScalarType out_type = out.scalar_type();
79 constexpr auto name = "any.dims_out";
80
81 ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
82 ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
83 CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
84 if (dim_list.has_value() && dim_list.value().empty()) {
85 const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
86 for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
87 out_data[out_ix] =
88 static_cast<CTYPE_OUT>(static_cast<bool>(in_data[out_ix]));
89 }
90 } else {
91 for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
92 bool any = false;
93 if (in.numel() > 0) {
94 any = map_reduce_over_dim_list<CTYPE_IN, bool>(
95 [](CTYPE_IN v) { return static_cast<bool>(v); },
96 [](bool outv, bool acc) { return acc || outv; },
97 in,
98 dim_list,
99 out_ix);
100 }
101 out_data[out_ix] = static_cast<CTYPE_OUT>(any);
102 }
103 }
104 });
105 });
106
107 return out;
108 }
109
any_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,bool keepdim,Tensor & out)110 Tensor& any_out(
111 KernelRuntimeContext& ctx,
112 const Tensor& in,
113 int64_t dim,
114 bool keepdim,
115 Tensor& out) {
116 (void)ctx;
117
118 ET_KERNEL_CHECK(
119 ctx,
120 check_reduction_args_single_dim(
121 in, dim, keepdim, {}, out, /*allow_empty_dim*/ true),
122 InvalidArgument,
123 out);
124
125 ET_KERNEL_CHECK(
126 ctx,
127 resize_reduction_out(in, dim, keepdim, out) == Error::Ok,
128 InvalidArgument,
129 out);
130
131 ET_KERNEL_CHECK(
132 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
133
134 ScalarType in_type = in.scalar_type();
135 ScalarType out_type = out.scalar_type();
136 constexpr auto name = "any.out";
137
138 ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
139 ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
140 CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
141 for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
142 CTYPE_OUT any = false;
143 if (in.numel() > 0) {
144 std::tuple<CTYPE_OUT, long> acc =
145 map_reduce_over_dim<CTYPE_IN, CTYPE_OUT>(
146 [](CTYPE_IN v) { return static_cast<bool>(v); },
147 [](bool outv, long, bool acc, long) {
148 return std::tuple<bool, long>{acc || outv, 0};
149 },
150 in,
151 dim,
152 out_ix);
153 any = std::get<0>(acc);
154 }
155 out_data[out_ix] = any;
156 }
157 });
158 });
159
160 return out;
161 }
162
163 } // namespace native
164 } // namespace executor
165 } // namespace torch
166