xref: /aosp_15_r20/external/executorch/examples/portable/custom_ops/custom_ops_1_out.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/kernel/kernel_includes.h>
10 
11 namespace custom {
12 namespace native {
13 
14 using exec_aten::ScalarType;
15 using exec_aten::Tensor;
16 using executorch::runtime::KernelRuntimeContext;
17 
18 namespace {
check_preconditions(const Tensor & in,Tensor & out)19 void check_preconditions(const Tensor& in, Tensor& out) {
20   ET_CHECK_MSG(
21       out.scalar_type() == ScalarType::Float,
22       "Expected out tensor to have dtype Float, but got %hhd instead",
23       static_cast<int8_t>(out.scalar_type()));
24   ET_CHECK_MSG(
25       in.scalar_type() == ScalarType::Float,
26       "Expected in tensor to have dtype Float, but got %hhd instead",
27       static_cast<int8_t>(in.scalar_type()));
28   ET_CHECK_MSG(
29       out.dim() == in.dim(),
30       "Number of dims of out tensor is not compatible with inputs");
31   ET_CHECK_MSG(
32       out.numel() == in.numel(),
33       "Number of elements of out tensor %zd is not compatible with inputs %zd",
34       ssize_t(out.numel()),
35       ssize_t(in.numel()));
36 }
37 } // namespace
38 
39 // mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)
40 // ExecuTorch-compatible function signature, with a KernelRuntimeContext.
mul3_out_impl(ET_UNUSED KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)41 Tensor& mul3_out_impl(
42     ET_UNUSED KernelRuntimeContext& ctx,
43     const Tensor& in,
44     Tensor& out) {
45   check_preconditions(in, out);
46   float* out_data = out.mutable_data_ptr<float>();
47   const float* in_data = in.const_data_ptr<float>();
48   for (size_t out_idx = 0; out_idx < out.numel(); ++out_idx) {
49     out_data[out_idx] = in_data[out_idx] * 3;
50   }
51   return out;
52 }
53 
54 } // namespace native
55 } // namespace custom
56