1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11 #include <executorch/runtime/kernel/kernel_includes.h>
12
13 namespace torch {
14 namespace executor {
15 namespace native {
16
rsub_scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,const Scalar & alpha,Tensor & out)17 Tensor& rsub_scalar_out(
18 KernelRuntimeContext& ctx,
19 const Tensor& a,
20 const Scalar& b,
21 const Scalar& alpha,
22 Tensor& out) {
23 ScalarType alpha_type = utils::get_scalar_dtype(alpha);
24
25 // Check alpha type
26 ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
27
28 // Common Dtype
29 ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
30
31 // Check Common Dtype
32 ET_KERNEL_CHECK(
33 ctx,
34 (common_type == out.scalar_type() && canCast(alpha_type, common_type)),
35 InvalidArgument,
36 out);
37
38 // Check Dim Order
39 ET_KERNEL_CHECK(
40 ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
41
42 // Resize
43 ET_KERNEL_CHECK(
44 ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
45
46 // Compute Dtype
47 ScalarType compute_type = utils::get_compute_type(common_type);
48
49 // @lint-ignore CLANGTIDY facebook-hte-CArray
50 static constexpr const char op_name[] = "rsub.Scalar_out";
51
52 ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
53 const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
54 const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
55 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56 [val_b, val_alpha](const CTYPE_COMPUTE val_a) {
57 return val_b - val_alpha * val_a;
58 },
59 ctx,
60 a,
61 utils::SupportedTensorDtypes::REALHBF16,
62 out,
63 utils::SupportedTensorDtypes::SAME_AS_COMMON);
64 });
65
66 return out;
67 }
68
69 } // namespace native
70 } // namespace executor
71 } // namespace torch
72