/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>

namespace custom {
namespace native {

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;

namespace {
// Checks that both tensors are Float and have the same rank and number of
// elements. ET_CHECK_MSG logs the message and aborts if the condition fails.
void check_preconditions(const Tensor& in, Tensor& out) {
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float,
      "Expected out tensor to have dtype Float, but got %hhd instead",
      static_cast<int8_t>(out.scalar_type()));
  ET_CHECK_MSG(
      in.scalar_type() == ScalarType::Float,
      "Expected in tensor to have dtype Float, but got %hhd instead",
      static_cast<int8_t>(in.scalar_type()));
  ET_CHECK_MSG(
      out.dim() == in.dim(),
      "Number of dims of out tensor is not compatible with inputs");
  ET_CHECK_MSG(
      out.numel() == in.numel(),
      "Number of elements of out tensor %zd is not compatible with inputs %zd",
      ssize_t(out.numel()),
      ssize_t(in.numel()));
}
} // namespace

// mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
// ExecuTorch-compatible function signature, with a KernelRuntimeContext.
Tensor& mul3_out_impl(
    ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& in,
    Tensor& out) {
  check_preconditions(in, out);
  float* out_data = out.mutable_data_ptr<float>();
  const float* in_data = in.const_data_ptr<float>();
  for (size_t out_idx = 0; out_idx < out.numel(); ++out_idx) {
    out_data[out_idx] = in_data[out_idx] * 3;
  }
  return out;
}
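
// Registration sketch (not part of the original file): to be callable as
// "custom::mul3.out" at runtime, this kernel still has to be registered.
// In the ExecuTorch examples this is typically driven by a YAML operator
// entry that the kernel codegen turns into registration code. A hedged
// alternative, shown here only as an illustration, would be the
// EXECUTORCH_LIBRARY macro, which would also require including
// <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>:
//
//   EXECUTORCH_LIBRARY(custom, "mul3.out", mul3_out_impl);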

} // namespace native
} // namespace custom