1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/runtime/kernel/kernel_includes.h>
10
11 namespace custom {
12 namespace native {
13
14 using exec_aten::ScalarType;
15 using exec_aten::Tensor;
16 using executorch::runtime::KernelRuntimeContext;
17
18 namespace {
check_preconditions(const Tensor & in,Tensor & out)19 void check_preconditions(const Tensor& in, Tensor& out) {
20 ET_CHECK_MSG(
21 out.scalar_type() == ScalarType::Float,
22 "Expected out tensor to have dtype Float, but got %d instead",
23 static_cast<int>(out.scalar_type()));
24 ET_CHECK_MSG(
25 in.scalar_type() == ScalarType::Float,
26 "Expected in tensor to have dtype Float, but got %d instead",
27 static_cast<int>(in.scalar_type()));
28 ET_CHECK_MSG(
29 out.dim() == in.dim(),
30 "Number of dims of out tensor is not compatible with inputs");
31 ET_CHECK_MSG(
32 out.numel() == in.numel(),
33 "Number of elements of out tensor %zd is not compatible with inputs %zd",
34 ssize_t(out.numel()),
35 ssize_t(in.numel()));
36 }
37 } // namespace
38
39 // mul4.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
40 // ATen-compatible function signature, without a KernelRuntimeContext.
/**
 * Computes `out[i] = in[i] * 4` elementwise for Float tensors.
 *
 * ATen-compatible signature (no KernelRuntimeContext) for the schema
 * `mul4.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)`.
 *
 * Each element is read and written at the same index, so passing the same
 * tensor as `in` and `out` is safe.
 *
 * @param in  Float input tensor.
 * @param out Float output tensor with the same dim count and element count
 *            as `in`; overwritten with the result.
 * @return    `out`.
 *
 * Precondition failures are reported through check_preconditions
 * (ET_CHECK_MSG).
 */
Tensor& mul4_out_impl(const Tensor& in, Tensor& out) {
  check_preconditions(in, out);

  float* const out_data = out.mutable_data_ptr<float>();
  const float* const in_data = in.const_data_ptr<float>();

  // numel() returns a signed count. Converting it once to size_t keeps the
  // loop index and bound the same unsigned type (no -Wsign-compare) and avoids
  // re-querying the bound on every iteration. An element count is never
  // negative, so the conversion is lossless.
  const size_t numel = static_cast<size_t>(out.numel());
  for (size_t i = 0; i < numel; ++i) {
    out_data[i] = in_data[i] * 4.0f;
  }
  return out;
}
50
51 // mul4.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
52 // ExecuTorch-compatible function signature, with a KernelRuntimeContext.
mul4_out_impl(ET_UNUSED KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)53 Tensor& mul4_out_impl(
54 ET_UNUSED KernelRuntimeContext& ctx,
55 const Tensor& in,
56 Tensor& out) {
57 mul4_out_impl(in, out);
58 return out;
59 }
60
61 } // namespace native
62 } // namespace custom
63