xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_rsub.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 namespace native {
16 
rsub_scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,const Scalar & alpha,Tensor & out)17 Tensor& rsub_scalar_out(
18     KernelRuntimeContext& ctx,
19     const Tensor& a,
20     const Scalar& b,
21     const Scalar& alpha,
22     Tensor& out) {
23   ScalarType alpha_type = utils::get_scalar_dtype(alpha);
24 
25   // Check alpha type
26   ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
27 
28   // Common Dtype
29   ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
30 
31   // Check Common Dtype
32   ET_KERNEL_CHECK(
33       ctx,
34       (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
35       InvalidArgument,
36       out);
37 
38   // Check Dim Order
39   ET_KERNEL_CHECK(
40       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
41 
42   // Resize
43   ET_KERNEL_CHECK(
44       ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
45 
46   // Compute Dtype
47   ScalarType compute_type = utils::get_compute_type(common_type);
48 
49   // @lint-ignore CLANGTIDY facebook-hte-CArray
50   static constexpr const char op_name[] = "rsub.Scalar_out";
51 
52   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
53     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
54     const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
55     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56         [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
57           return val_b - val_alpha * val_a;
58         },
59         ctx,
60         a,
61         utils::SupportedTensorDtypes::REALHBF16,
62         out,
63         utils::SupportedTensorDtypes::SAME_AS_COMMON);
64   });
65 
66   return out;
67 }
68 
69 } // namespace native
70 } // namespace executor
71 } // namespace torch
72