1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 #include <executorch/runtime/platform/assert.h>
13
14 namespace torch {
15 namespace executor {
16 namespace native {
17
mul_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)18 Tensor& mul_out(
19 KernelRuntimeContext& ctx,
20 const Tensor& a,
21 const Tensor& b,
22 Tensor& out) {
23 // Common Dtype
24 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
25
26 // Check Common Dtype
27 ET_KERNEL_CHECK(
28 ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
29
30 // Check Dim Order
31 ET_KERNEL_CHECK(
32 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
33
34 // Resize
35 ET_KERNEL_CHECK(
36 ctx,
37 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
38 InvalidArgument,
39 out);
40
41 // Compute Dtype
42 ScalarType compute_type = utils::get_compute_type(common_type);
43
44 // @lint-ignore CLANGTIDY facebook-hte-CArray
45 static constexpr const char op_name[] = "mul.out";
46
47 ET_KERNEL_CHECK(
48 ctx,
49 (executorch::runtime::isRealType(compute_type) ||
50 compute_type == ScalarType::Bool),
51 InvalidArgument,
52 out);
53
54 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
55 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
56 [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
57 return val_a * val_b;
58 },
59 ctx,
60 a,
61 utils::SupportedTensorDtypes::REALHBBF16,
62 b,
63 utils::SupportedTensorDtypes::REALHBBF16,
64 out,
65 utils::SupportedTensorDtypes::REALHBBF16);
66 });
67
68 return out;
69 }
70
mul_scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)71 Tensor& mul_scalar_out(
72 KernelRuntimeContext& ctx,
73 const Tensor& a,
74 const Scalar& b,
75 Tensor& out) {
76 // Common Dtype
77 ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
78
79 // Check Common Dtype
80 ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
81
82 // Check Dim Order
83 ET_KERNEL_CHECK(
84 ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
85
86 // Resize
87 ET_KERNEL_CHECK(
88 ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
89
90 // Compute Dtype
91 ScalarType compute_type = utils::get_compute_type(common_type);
92
93 // @lint-ignore CLANGTIDY facebook-hte-CArray
94 static constexpr const char op_name[] = "mul.Scalar_out";
95
96 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
97 const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
98 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
99 [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; },
100 ctx,
101 a,
102 utils::SupportedTensorDtypes::REALHBBF16,
103 out,
104 utils::SupportedTensorDtypes::SAME_AS_COMMON);
105 });
106
107 return out;
108 }
109
110 } // namespace native
111 } // namespace executor
112 } // namespace torch
113