1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13 #include <executorch/runtime/platform/assert.h>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
add_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,const Scalar & alpha,Tensor & out)19 Tensor& add_out(
20 KernelRuntimeContext& ctx,
21 const Tensor& a,
22 const Tensor& b,
23 const Scalar& alpha,
24 Tensor& out) {
25 // Common Dtype
26 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
27
28 // Check Common Dtype
29 ET_KERNEL_CHECK(
30 ctx,
31 (canCast(common_type, out.scalar_type()) &&
32 check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
33 InvalidArgument,
34 out);
35
36 // Check Dim Order
37 ET_KERNEL_CHECK(
38 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
39
40 // Resize
41 ET_KERNEL_CHECK(
42 ctx,
43 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
44 InvalidArgument,
45 out);
46
47 // Compute Dtype
48 ScalarType compute_type = utils::get_compute_type(common_type);
49
50 // @lint-ignore CLANGTIDY facebook-hte-CArray
51 static constexpr const char op_name[] = "add.out";
52
53 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
54 const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
55 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56 [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
57 return val_a + val_alpha * val_b;
58 },
59 ctx,
60 a,
61 utils::SupportedTensorDtypes::REALHBBF16,
62 b,
63 utils::SupportedTensorDtypes::REALHBBF16,
64 out,
65 utils::SupportedTensorDtypes::REALHBBF16);
66 });
67
68 return out;
69 }
70
add_scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,const Scalar & alpha,Tensor & out)71 Tensor& add_scalar_out(
72 KernelRuntimeContext& ctx,
73 const Tensor& a,
74 const Scalar& b,
75 const Scalar& alpha,
76 Tensor& out) {
77 // Common Dtype
78 ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
79
80 // Check Common Dtype
81 ET_KERNEL_CHECK(
82 ctx,
83 (common_type == out.scalar_type() &&
84 check_alpha_type(utils::get_scalar_dtype(alpha), common_type)),
85 InvalidArgument,
86 out);
87
88 // Check Dim Order
89 ET_KERNEL_CHECK(
90 ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
91
92 // Resize
93 ET_KERNEL_CHECK(
94 ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
95
96 // Compute Dtype
97 ScalarType compute_type = utils::get_compute_type(common_type);
98
99 // @lint-ignore CLANGTIDY facebook-hte-CArray
100 static constexpr const char op_name[] = "add.Scalar_out";
101
102 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
103 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
104 [b, alpha](const CTYPE_COMPUTE val_a) {
105 CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
106 CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
107 return val_a + val_alpha * val_b;
108 },
109 ctx,
110 a,
111 utils::SupportedTensorDtypes::REALHBBF16,
112 out,
113 utils::SupportedTensorDtypes::SAME_AS_COMMON);
114 });
115
116 return out;
117 }
118
119 } // namespace native
120 } // namespace executor
121 } // namespace torch
122