/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/aten_util/aten_bridge.h>

#include <executorch/runtime/platform/assert.h>
#include <cstring>

namespace executorch {
namespace extension {

namespace {
void check_tensor_meta(const at::Tensor& a, const exec_aten::Tensor& b) {
  // Check sizes/strides pointers
  ET_CHECK_MSG(
      b.sizes().data() != nullptr, "ETensor must have valid sizes array");
  ET_CHECK_MSG(
      b.strides().data() != nullptr, "ETensor must have valid strides array");
  // Check disabled because the ASR model produces a 1-element tensor with a
  // different rank.
  /*
  ET_CHECK_MSG(
      a.dim() == b.dim(),
      "at::Tensor and ETensor must have same rank."
      " at::Tensor rank %zd, ETensor rank %zd.",
      a.dim(),
      b.dim());
  */
  // check sizes
  for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
    ET_CHECK_MSG(
        a.size(i) == b.size(i),
        "Sizes don't match at index:%zu, a size %zd != b size %zd",
        i,
        ssize_t(a.size(i)),
        ssize_t(b.size(i)));
  }
  // check strides
  for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
    // Don't compare strides when the size at this dim is 1:
    // a tensor is non-contiguous only when strides don't match
    // product(sizes[i:]) for dims where size(i) > 1.
    // Strong assumption that must be tested and validated.
    ET_CHECK_MSG(
        (a.size(i) == 1 || (a.stride(i) == b.strides()[i])),
        "Strides don't match at index:%zu, a stride %zd != b stride %zd",
        i,
        ssize_t(a.stride(i)),
        ssize_t(b.strides()[i]));
  }
  // check dtype
  ET_CHECK_MSG(
      b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()),
      "dtypes don't match a %hhd vs. b %hhd",
      static_cast<int8_t>(torch_to_executorch_scalar_type(a.options().dtype())),
      static_cast<int8_t>(b.scalar_type()));
}
} // namespace
executorch::runtime::etensor::ScalarType torch_to_executorch_scalar_type(
    caffe2::TypeMeta type) {
  const auto intermediate =
      static_cast<std::underlying_type<c10::ScalarType>::type>(
          c10::typeMetaToScalarType(type));

  ET_CHECK_MSG(
      intermediate >= 0 &&
          intermediate <= static_cast<std::underlying_type<
                              executorch::runtime::etensor::ScalarType>::type>(
                              executorch::runtime::etensor::ScalarType::UInt64),
      "ScalarType %d unsupported in ExecuTorch",
      intermediate);
  return static_cast<executorch::runtime::etensor::ScalarType>(intermediate);
}
c10::ScalarType executorch_to_torch_scalar_type(
    torch::executor::ScalarType type) {
  const auto intermediate = static_cast<
      std::underlying_type<executorch::runtime::etensor::ScalarType>::type>(
      type);

  ET_CHECK_MSG(
      intermediate >= 0 &&
          intermediate <= static_cast<std::underlying_type<
                              executorch::runtime::etensor::ScalarType>::type>(
                              executorch::runtime::etensor::ScalarType::UInt64),
      "ScalarType %d unsupported in ExecuTorch",
      intermediate);
  return static_cast<c10::ScalarType>(intermediate);
}
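
// Example (illustrative, not part of this file): the two conversions above
// rely on c10::ScalarType and executorch::runtime::etensor::ScalarType using
// the same underlying integer values for the supported dtypes, so a round
// trip is the identity:
//
//   auto et_type = torch_to_executorch_scalar_type(
//       caffe2::TypeMeta::Make<float>()); // etensor ScalarType::Float
//   c10::ScalarType at_type = executorch_to_torch_scalar_type(et_type);
//   // at_type == c10::ScalarType::Float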
/*
 * The following makes two assumptions:
 * 1. aten_tensor outlives the scope within which mutable_et is consumed.
 * 2. The memory previously backing mutable_et is not freed here. Under the
 *    strong assumption that this memory is arena-allocated, with a lifetime
 *    tied to the model's lifetime, it is not leaked: it is released when the
 *    arena is freed.
 * @param[in] aten_tensor Input at::Tensor
 * @param[in,out] mutable_et ETensor whose underlying memory will now alias
 * aten_tensor
 */
void alias_etensor_to_attensor(
    at::Tensor& aten_tensor,
    torch::executor::Tensor& mutable_et) {
  // TODO(kimishpatel): contiguous according to memformat
  // Right now we assume everything is channels-first contiguous.
  // Note that the input tensor must be contiguous for us to alias.
  // Mixing aliasing and copying is dangerous: if we aliased this instance
  // of mutable_et to aten_tensor in a previous call, then copying in the
  // next call would not be the correct behavior.
  ET_CHECK_MSG(aten_tensor.is_contiguous(), "Input tensor must be contiguous");
  check_tensor_meta(aten_tensor, mutable_et);
  mutable_et.unsafeGetTensorImpl()->set_data(aten_tensor.mutable_data_ptr());
}
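
// Example (illustrative sketch; the ETensor source is an assumption, any
// runtime-owned tensor with matching sizes/strides/dtype works):
//
//   at::Tensor input = at::ones({1, 3, 8, 8}); // contiguous by construction
//   torch::executor::Tensor& et_input = /* e.g. a Method input */;
//   alias_etensor_to_attensor(input, et_input);
//   // et_input now reads and writes input's storage; `input` must outlive
//   // every use of et_input.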
at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
  c10::ScalarType dtype =
      executorch_to_torch_scalar_type(etensor.scalar_type());
  std::vector<int64_t> at_tensor_sizes(
      etensor.sizes().begin(), etensor.sizes().end());
  std::vector<int64_t> at_tensor_strides(
      etensor.strides().begin(), etensor.strides().end());

  at::Tensor t = at::from_blob(
      etensor.mutable_data_ptr(),
      at_tensor_sizes,
      at_tensor_strides,
      at::TensorOptions(dtype));

  check_tensor_meta(t, etensor);
  return t;
}
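
// Example (illustrative): wrap a runtime-owned output as an ATen tensor
// without copying. The returned at::Tensor does not own the storage, so the
// ETensor (and the arena backing it) must outlive it; clone() if the data
// must survive longer:
//
//   const torch::executor::Tensor& et_out = /* e.g. a Method output */;
//   at::Tensor at_out = alias_attensor_to_etensor(et_out);
//   at::Tensor owned = at_out.clone(); // detaches from runtime memory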
TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t) {
  return make_tensor_ptr(
      {t.sizes().begin(), t.sizes().end()},
      t.mutable_data_ptr(),
      torch::executor::ScalarType(t.scalar_type()));
}
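
// Example (illustrative): only sizes are forwarded to make_tensor_ptr, so
// the resulting TensorPtr assumes a contiguous layout and aliases t's
// storage; `t` must stay alive and contiguous while the pointer is in use:
//
//   at::Tensor t = at::rand({2, 2}); // contiguous by construction
//   TensorPtr tp = alias_tensor_ptr_to_attensor(t);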
} // namespace extension
} // namespace executorch