xref: /aosp_15_r20/external/executorch/runtime/executor/method_meta.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/runtime/core/error.h>
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
12 #include <executorch/runtime/core/result.h>
13 #include <executorch/runtime/core/span.h>
14 #include <executorch/runtime/core/tag.h>
15 #include <executorch/runtime/executor/method_meta.h>
16 #include <executorch/schema/program_generated.h>
17 
18 namespace executorch {
19 namespace runtime {
20 
21 namespace {
get_tag(flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::EValue>>::return_type serialization_value,size_t index)22 Result<Tag> get_tag(
23     flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::EValue>>::
24         return_type serialization_value,
25     size_t index) {
26   switch (serialization_value->val_type()) {
27     case executorch_flatbuffer::KernelTypes::Null: {
28       return Tag::None;
29     } break;
30     case executorch_flatbuffer::KernelTypes::Int: {
31       return Tag::Int;
32     } break;
33     case executorch_flatbuffer::KernelTypes::Double: {
34       return Tag::Double;
35     } break;
36     case executorch_flatbuffer::KernelTypes::Bool: {
37       return Tag::Bool;
38     } break;
39     case executorch_flatbuffer::KernelTypes::String: {
40       return Tag::String;
41     } break;
42     case executorch_flatbuffer::KernelTypes::Tensor: {
43       return Tag::Tensor;
44     } break;
45     default:
46       ET_LOG(
47           Error,
48           "Invalid tag: %zu input idx: %zu",
49           (size_t)serialization_value->val_type(),
50           index);
51       return Error::Internal;
52   }
53 }
54 
calculate_nbytes(Span<const int32_t> sizes,exec_aten::ScalarType scalar_type)55 size_t calculate_nbytes(
56     Span<const int32_t> sizes,
57     exec_aten::ScalarType scalar_type) {
58   ssize_t n = 1;
59   for (ssize_t i = 0; i < sizes.size(); i++) {
60     n *= sizes[i];
61   }
62   // Use the full namespace to disambiguate from c10::elementSize.
63   return n * executorch::runtime::elementSize(scalar_type);
64 }
65 
66 } // namespace
67 
TensorInfo(Span<const int32_t> sizes,Span<const uint8_t> dim_order,exec_aten::ScalarType scalar_type,const bool is_memory_planned)68 TensorInfo::TensorInfo(
69     Span<const int32_t> sizes,
70     Span<const uint8_t> dim_order,
71     exec_aten::ScalarType scalar_type,
72     const bool is_memory_planned)
73     : sizes_(sizes),
74       dim_order_(dim_order),
75       scalar_type_(scalar_type),
76       is_memory_planned_(is_memory_planned),
77       nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
78 
sizes() const79 Span<const int32_t> TensorInfo::sizes() const {
80   return sizes_;
81 }
82 
dim_order() const83 Span<const uint8_t> TensorInfo::dim_order() const {
84   return dim_order_;
85 }
86 
scalar_type() const87 exec_aten::ScalarType TensorInfo::scalar_type() const {
88   return scalar_type_;
89 }
90 
is_memory_planned() const91 bool TensorInfo::is_memory_planned() const {
92   return is_memory_planned_;
93 }
94 
nbytes() const95 size_t TensorInfo::nbytes() const {
96   return nbytes_;
97 }
98 
MethodMeta(const executorch_flatbuffer::ExecutionPlan * s_plan)99 MethodMeta::MethodMeta(const executorch_flatbuffer::ExecutionPlan* s_plan)
100     : s_plan_(s_plan) {}
101 
name() const102 const char* MethodMeta::name() const {
103   return s_plan_->name()->c_str();
104 }
105 
num_inputs() const106 size_t MethodMeta::num_inputs() const {
107   return s_plan_->inputs()->size();
108 }
109 
input_tag(size_t index) const110 Result<Tag> MethodMeta::input_tag(size_t index) const {
111   auto num_inputs = this->num_inputs();
112   ET_CHECK_OR_RETURN_ERROR(
113       index >= 0 && index < num_inputs,
114       InvalidArgument,
115       "index %zu out of range. num_inputs: %zu",
116       index,
117       num_inputs);
118   auto input_index = s_plan_->inputs()->Get(index);
119   auto serialization_value = s_plan_->values()->Get(input_index);
120   return get_tag(serialization_value, index);
121 }
122 
input_tensor_meta(size_t index) const123 Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const {
124   auto tag = this->input_tag(index);
125   if (!tag.ok()) {
126     return tag.error();
127   }
128   ET_CHECK_OR_RETURN_ERROR(
129       tag.get() == Tag::Tensor,
130       InvalidArgument,
131       "Tag: %zu input: %zu is not Tensor",
132       (size_t)tag.get(),
133       index);
134   auto input_index = s_plan_->inputs()->Get(index);
135   auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
136   return TensorInfo(
137       Span<const int32_t>(
138           tensor_value->sizes()->data(), tensor_value->sizes()->size()),
139       Span<const uint8_t>(
140           tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
141       static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()),
142       tensor_value->allocation_info() != nullptr ||
143           tensor_value->data_buffer_idx() !=
144               0); // Count constant returns as memory planned.
145 }
146 
num_outputs() const147 size_t MethodMeta::num_outputs() const {
148   return s_plan_->outputs()->size();
149 }
150 
output_tag(size_t index) const151 Result<Tag> MethodMeta::output_tag(size_t index) const {
152   auto num_outputs = this->num_outputs();
153   ET_CHECK_OR_RETURN_ERROR(
154       index >= 0 && index < num_outputs,
155       InvalidArgument,
156       "index %zu out of range. num_outputs: %zu",
157       index,
158       num_outputs);
159   auto input_index = s_plan_->outputs()->Get(index);
160   auto serialization_value = s_plan_->values()->Get(input_index);
161   return get_tag(serialization_value, index);
162 }
163 
output_tensor_meta(size_t index) const164 Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const {
165   auto tag = this->output_tag(index);
166   if (!tag.ok()) {
167     return tag.error();
168   }
169   ET_CHECK_OR_RETURN_ERROR(
170       tag.get() == Tag::Tensor,
171       InvalidArgument,
172       "Tag: %zu output: %zu is not Tensor",
173       (size_t)tag.get(),
174       index);
175   auto output_index = s_plan_->outputs()->Get(index);
176   auto tensor_value = s_plan_->values()->Get(output_index)->val_as_Tensor();
177 
178   return TensorInfo(
179       Span<const int32_t>(
180           tensor_value->sizes()->data(), tensor_value->sizes()->size()),
181       Span<const uint8_t>(
182           tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
183       static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()),
184       tensor_value->allocation_info() != nullptr ||
185           tensor_value->data_buffer_idx() !=
186               0); // Count constant returns as memory planned.
187 }
188 
num_memory_planned_buffers() const189 size_t MethodMeta::num_memory_planned_buffers() const {
190   if (s_plan_->non_const_buffer_sizes() == nullptr) {
191     return 0;
192   }
193   const size_t size = s_plan_->non_const_buffer_sizes()->size();
194   // Index zero is reserved internally, and we hide it from users. The actual
195   // number of buffers is one fewer than the actual size of this list in the
196   // program.
197   return size > 0 ? size - 1 : 0;
198 }
199 
memory_planned_buffer_size(size_t index) const200 Result<int64_t> MethodMeta::memory_planned_buffer_size(size_t index) const {
201   auto num_buffers = this->num_memory_planned_buffers();
202   ET_CHECK_OR_RETURN_ERROR(
203       index >= 0 && index < num_buffers,
204       InvalidArgument,
205       "index %zu out of range. num_buffers: %zu",
206       index,
207       num_buffers);
208   // Index zero is reserved internally, and we hide it from users. Adjust the
209   // provided index to point to one of the actual buffers.
210   return s_plan_->non_const_buffer_sizes()->Get(index + 1);
211 }
212 
213 } // namespace runtime
214 } // namespace executorch
215