1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/runner_util/inputs.h>
10
11 #include <algorithm>
12
13 #include <executorch/runtime/core/exec_aten/exec_aten.h>
14 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
15 #include <executorch/runtime/executor/method.h>
16 #include <executorch/runtime/executor/method_meta.h>
17 #include <executorch/runtime/platform/log.h>
18
19 using exec_aten::Tensor;
20 using exec_aten::TensorImpl;
21 using executorch::runtime::Error;
22 using executorch::runtime::Method;
23 using executorch::runtime::TensorInfo;
24
25 namespace executorch {
26 namespace extension {
27 namespace internal {
28
29 namespace {
30 /**
31 * Sets all elements of a tensor to 1.
32 */
fill_ones(torch::executor::Tensor tensor)33 Error fill_ones(torch::executor::Tensor tensor) {
34 #define FILL_CASE(T, n) \
35 case (torch::executor::ScalarType::n): \
36 std::fill( \
37 tensor.mutable_data_ptr<T>(), \
38 tensor.mutable_data_ptr<T>() + tensor.numel(), \
39 1); \
40 break;
41
42 switch (tensor.scalar_type()) {
43 ET_FORALL_REAL_TYPES_AND(Bool, FILL_CASE)
44 default:
45 ET_LOG(Error, "Unsupported scalar type %d", (int)tensor.scalar_type());
46 return Error::InvalidArgument;
47 }
48
49 #undef FILL_CASE
50
51 return Error::Ok;
52 }
53 } // namespace
54
fill_and_set_input(Method & method,TensorInfo & tensor_meta,size_t input_index,void * data_ptr)55 Error fill_and_set_input(
56 Method& method,
57 TensorInfo& tensor_meta,
58 size_t input_index,
59 void* data_ptr) {
60 TensorImpl impl = TensorImpl(
61 tensor_meta.scalar_type(),
62 /*dim=*/tensor_meta.sizes().size(),
63 // These const pointers will not be modified because we never resize this
64 // short-lived TensorImpl. It only exists so that set_input() can verify
65 // that the shape is correct; the Method manages its own sizes and
66 // dim_order arrays for the input.
67 const_cast<TensorImpl::SizesType*>(tensor_meta.sizes().data()),
68 data_ptr,
69 const_cast<TensorImpl::DimOrderType*>(tensor_meta.dim_order().data()));
70 Tensor t(&impl);
71 ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(t));
72 return method.set_input(t, input_index);
73 }
74
75 } // namespace internal
76 } // namespace extension
77 } // namespace executorch
78