/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

namespace {

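// Pick the type in which the division is carried out. Division always
// produces a floating-point result, so if either input is floating point
// the (promoted) floating type is used; pure integer/bool inputs fall back
// to Float. Complex, quantized, and bits types are rejected outright.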
ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
  ET_CHECK(
      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
  ET_CHECK(
      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));

  if (isFloatingType(a_type) && isFloatingType(b_type)) {
    return promoteTypes(a_type, b_type);
  } else if (isFloatingType(a_type)) {
    return a_type;
  } else if (isFloatingType(b_type)) {
    return b_type;
  }
  return ScalarType::Float;
}

} // namespace

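// div.out: element-wise division, dispatched across three tiers:
//   1. a vectorized tensor-by-scalar path when either input has a single
//      element and all dtypes match (Half excluded),
//   2. vectorized same-shape and 2d-by-1d broadcast paths chosen by
//      select_optimized_path(), and
//   3. a generic per-element fallback that casts through a common compute
//      type and handles arbitrary broadcasting.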
Tensor& opt_div_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

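  // Fast path: one operand is a single-element tensor, so treat it as a
  // scalar and run a single vectorized map over the other operand.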
  if (a.numel() == 1 || b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
      const Tensor* tensor;
      const Tensor* scalar;
      ScalarType tensor_type;
      ScalarType scalar_type;
      if (a.numel() == 1) {
        tensor = &b;
        tensor_type = b_type;
        scalar = &a;
        scalar_type = a_type;
      } else {
        tensor = &a;
        tensor_type = a_type;
        scalar = &b;
        scalar_type = b_type;
      }
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);
      ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
        ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
          CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
          CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);

          using Vec = executorch::vec::Vectorized<CTYPE>;
          if (a.numel() == 1) {
            executorch::vec::map<CTYPE>(
                [scalar_casted](Vec x) { return Vec(scalar_casted) / x; },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          } else {
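            // Scalar divisor: hoist the reciprocal out of the loop and
            // multiply instead of dividing. Rounding may differ slightly
            // from a true division, but the multiply vectorizes much more
            // cheaply.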
            Vec inv_scalar_casted_vec(CTYPE(1) / scalar_casted);
            executorch::vec::map<CTYPE>(
                [inv_scalar_casted_vec](Vec x) {
                  return x * inv_scalar_casted_vec;
                },
                out.mutable_data_ptr<CTYPE>(),
                tensor->const_data_ptr<CTYPE>(),
                out.numel());
          }
        });
      });
      return out;
    }
  }

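  // No single-element fast path applies; check whether the shapes admit one
  // of the vectorized element-wise paths below.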
  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Resize for dynamic shape
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "div.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [](Vec x, Vec y) { return x / y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
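    // A 2-d (or trailing-2-d) tensor broadcast against a 1-d tensor. In the
    // "reverse arguments" variant the 1-d operand is `a`, so the operands
    // are swapped here and the lambda below computes y / x to preserve
    // a / b.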
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
    } else {
      // Catch failure to update this logic when adding a new broadcasting
      // possibility.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
    }
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      if (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [](Vec x, Vec y) { return y / x; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      } else {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [](Vec x, Vec y) { return x / y; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      }
    });
  } else {
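    // Generic fallback: promote to a common compute type, divide element by
    // element, and cast on the way out. apply_binary_elementwise_fn handles
    // any broadcast pattern the optimized paths above do not.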
    ScalarType common_type = get_compute_type(a_type, b_type);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
                [](const CTYPE_A val_a, const CTYPE_B val_b) {
                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                  CTYPE_IN value = a_casted / b_casted;

                  return static_cast<CTYPE_OUT>(value);
                },
                a,
                b,
                out);
          });
        });
      });
    });
  }

  return out;
}

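// div.Scalar_out: divide a tensor by a Scalar. The compute type is the
// input's own floating type (or Float for integral inputs), and the output
// tensor must already have that dtype.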
Tensor& opt_div_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  // Resize for dynamic shape
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

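  // Fast path: no dtype conversion needed, so precompute 1 / b once and do
  // a vectorized multiply across the whole tensor.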
  if (a_type == common_type && a_type == out_type) {
    ET_SWITCH_REAL_TYPES(a_type, ctx, "div.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_REAL_TYPES_AND(
          Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
            CTYPE_B b_val;
            ET_EXTRACT_SCALAR(b, b_val);
            CTYPE b_casted = static_cast<CTYPE>(b_val);

            using Vec = executorch::vec::Vectorized<CTYPE>;
            Vec inv_b_casted_vec(CTYPE(1) / b_casted);
            executorch::vec::map<CTYPE>(
                [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
                out.mutable_data_ptr<CTYPE>(),
                a.const_data_ptr<CTYPE>(),
                out.numel());
          });
    });
  } else {
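    // Slow path: cast every element through the common compute type. The
    // reciprocal of b is still hoisted out of the loop.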
    ET_SWITCH_REAL_TYPES_AND(
        Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
          ET_SWITCH_REAL_TYPES_AND(
              Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
                ET_SWITCH_REAL_TYPES(
                    common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
                      ET_SWITCH_REAL_TYPES(
                          out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
                            CTYPE_B b_val;
                            ET_EXTRACT_SCALAR(b, b_val);
                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                            CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;

                            const size_t n = a.numel();
                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
                            CTYPE_OUT* out_data =
                                out.mutable_data_ptr<CTYPE_OUT>();
                            for (size_t i = 0; i < n; ++i) {
                              out_data[i] = static_cast<CTYPE_OUT>(
                                  static_cast<CTYPE_IN>(a_data[i]) *
                                  inv_b_casted);
                            }
                          });
                    });
              });
        });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch