xref: /aosp_15_r20/external/executorch/extension/runner_util/inputs_portable.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/runner_util/inputs.h>
10 
11 #include <algorithm>
12 
13 #include <executorch/runtime/core/exec_aten/exec_aten.h>
14 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
15 #include <executorch/runtime/executor/method.h>
16 #include <executorch/runtime/executor/method_meta.h>
17 #include <executorch/runtime/platform/log.h>
18 
19 using exec_aten::Tensor;
20 using exec_aten::TensorImpl;
21 using executorch::runtime::Error;
22 using executorch::runtime::Method;
23 using executorch::runtime::TensorInfo;
24 
25 namespace executorch {
26 namespace extension {
27 namespace internal {
28 
29 namespace {
30 /**
31  * Sets all elements of a tensor to 1.
32  */
fill_ones(torch::executor::Tensor tensor)33 Error fill_ones(torch::executor::Tensor tensor) {
34 #define FILL_CASE(T, n)                                \
35   case (torch::executor::ScalarType::n):               \
36     std::fill(                                         \
37         tensor.mutable_data_ptr<T>(),                  \
38         tensor.mutable_data_ptr<T>() + tensor.numel(), \
39         1);                                            \
40     break;
41 
42   switch (tensor.scalar_type()) {
43     ET_FORALL_REAL_TYPES_AND(Bool, FILL_CASE)
44     default:
45       ET_LOG(Error, "Unsupported scalar type %d", (int)tensor.scalar_type());
46       return Error::InvalidArgument;
47   }
48 
49 #undef FILL_CASE
50 
51   return Error::Ok;
52 }
53 } // namespace
54 
fill_and_set_input(Method & method,TensorInfo & tensor_meta,size_t input_index,void * data_ptr)55 Error fill_and_set_input(
56     Method& method,
57     TensorInfo& tensor_meta,
58     size_t input_index,
59     void* data_ptr) {
60   TensorImpl impl = TensorImpl(
61       tensor_meta.scalar_type(),
62       /*dim=*/tensor_meta.sizes().size(),
63       // These const pointers will not be modified because we never resize this
64       // short-lived TensorImpl. It only exists so that set_input() can verify
65       // that the shape is correct; the Method manages its own sizes and
66       // dim_order arrays for the input.
67       const_cast<TensorImpl::SizesType*>(tensor_meta.sizes().data()),
68       data_ptr,
69       const_cast<TensorImpl::DimOrderType*>(tensor_meta.dim_order().data()));
70   Tensor t(&impl);
71   ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(t));
72   return method.set_input(t, input_index);
73 }
74 
75 } // namespace internal
76 } // namespace extension
77 } // namespace executorch
78